MilesCranmer commited on
Commit
45b290b
β€’
2 Parent(s): 6a4fa2c 4db1c62

Merge pull request #48 from MilesCranmer/torch-export

Browse files
.github/workflows/CI.yml CHANGED
@@ -61,17 +61,23 @@ jobs:
61
  python setup.py install
62
  - name: "Install Coverage tool"
63
  run: pip install coverage coveralls
 
 
 
64
  - name: "Install JAX"
65
  if: matrix.os != 'windows-latest'
66
  run: pip install jax jaxlib # (optional import)
67
  shell: bash
68
- - name: "Run tests"
69
- run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
70
- shell: bash
71
  - name: "Run JAX tests"
72
  if: matrix.os != 'windows-latest'
73
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
74
  shell: bash
 
 
 
 
 
 
75
  - name: Coveralls
76
  env:
77
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 
61
  python setup.py install
62
  - name: "Install Coverage tool"
63
  run: pip install coverage coveralls
64
+ - name: "Run tests"
65
+ run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
66
+ shell: bash
67
  - name: "Install JAX"
68
  if: matrix.os != 'windows-latest'
69
  run: pip install jax jaxlib # (optional import)
70
  shell: bash
 
 
 
71
  - name: "Run JAX tests"
72
  if: matrix.os != 'windows-latest'
73
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
74
  shell: bash
75
+ - name: "Install Torch"
76
+ run: pip install torch # (optional import)
77
+ shell: bash
78
+ - name: "Run Torch tests"
79
+ run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_torch
80
+ shell: bash
81
  - name: Coveralls
82
  env:
83
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
pysr/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
3
- from .export import sympy2jax
 
 
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
3
+ from .export_jax import sympy2jax
4
+ from .export_torch import sympy2torch
pysr/{export.py β†’ export_jax.py} RENAMED
@@ -3,59 +3,53 @@ import sympy
3
  import string
4
  import random
5
 
6
- try:
7
- import jax
8
- from jax import numpy as jnp
9
- from jax.scipy import special as jsp
10
-
11
  # Special since need to reduce arguments.
12
- MUL = 0
13
- ADD = 1
14
-
15
- _jnp_func_lookup = {
16
- sympy.Mul: MUL,
17
- sympy.Add: ADD,
18
- sympy.div: "jnp.div",
19
- sympy.Abs: "jnp.abs",
20
- sympy.sign: "jnp.sign",
21
- # Note: May raise error for ints.
22
- sympy.ceiling: "jnp.ceil",
23
- sympy.floor: "jnp.floor",
24
- sympy.log: "jnp.log",
25
- sympy.exp: "jnp.exp",
26
- sympy.sqrt: "jnp.sqrt",
27
- sympy.cos: "jnp.cos",
28
- sympy.acos: "jnp.acos",
29
- sympy.sin: "jnp.sin",
30
- sympy.asin: "jnp.asin",
31
- sympy.tan: "jnp.tan",
32
- sympy.atan: "jnp.atan",
33
- sympy.atan2: "jnp.atan2",
34
- # Note: Also may give NaN for complex results.
35
- sympy.cosh: "jnp.cosh",
36
- sympy.acosh: "jnp.acosh",
37
- sympy.sinh: "jnp.sinh",
38
- sympy.asinh: "jnp.asinh",
39
- sympy.tanh: "jnp.tanh",
40
- sympy.atanh: "jnp.atanh",
41
- sympy.Pow: "jnp.power",
42
- sympy.re: "jnp.real",
43
- sympy.im: "jnp.imag",
44
- sympy.arg: "jnp.angle",
45
- # Note: May raise error for ints and complexes
46
- sympy.erf: "jsp.erf",
47
- sympy.erfc: "jsp.erfc",
48
- sympy.LessThan: "jnp.less",
49
- sympy.GreaterThan: "jnp.greater",
50
- sympy.And: "jnp.logical_and",
51
- sympy.Or: "jnp.logical_or",
52
- sympy.Not: "jnp.logical_not",
53
- sympy.Max: "jnp.max",
54
- sympy.Min: "jnp.min",
55
- sympy.Mod: "jnp.mod",
56
- }
57
- except ImportError:
58
- ...
59
 
60
  def sympy2jaxtext(expr, parameters, symbols_in):
61
  if issubclass(expr.func, sympy.Float):
@@ -75,7 +69,28 @@ def sympy2jaxtext(expr, parameters, symbols_in):
75
  else:
76
  return f'{_func}({", ".join(args)})'
77
 
78
- def sympy2jax(equation, symbols_in):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  """Returns a function f and its parameters;
80
  the function takes an input matrix, and a list of arguments:
81
  f(X, parameters)
@@ -146,9 +161,15 @@ def sympy2jax(equation, symbols_in):
146
  # 3.5427954 , -2.7479894 ], dtype=float32)
147
  ```
148
  """
 
 
 
 
 
 
149
  parameters = []
150
- functional_form_text = sympy2jaxtext(equation, parameters, symbols_in)
151
- hash_string = 'A_' + str(abs(hash(str(equation) + str(symbols_in))))
152
  text = f"def {hash_string}(X, parameters):\n"
153
  text += " return "
154
  text += functional_form_text
 
3
  import string
4
  import random
5
 
 
 
 
 
 
6
  # Special since need to reduce arguments.
7
+ MUL = 0
8
+ ADD = 1
9
+
10
+ _jnp_func_lookup = {
11
+ sympy.Mul: MUL,
12
+ sympy.Add: ADD,
13
+ sympy.div: "jnp.div",
14
+ sympy.Abs: "jnp.abs",
15
+ sympy.sign: "jnp.sign",
16
+ # Note: May raise error for ints.
17
+ sympy.ceiling: "jnp.ceil",
18
+ sympy.floor: "jnp.floor",
19
+ sympy.log: "jnp.log",
20
+ sympy.exp: "jnp.exp",
21
+ sympy.sqrt: "jnp.sqrt",
22
+ sympy.cos: "jnp.cos",
23
+ sympy.acos: "jnp.acos",
24
+ sympy.sin: "jnp.sin",
25
+ sympy.asin: "jnp.asin",
26
+ sympy.tan: "jnp.tan",
27
+ sympy.atan: "jnp.atan",
28
+ sympy.atan2: "jnp.atan2",
29
+ # Note: Also may give NaN for complex results.
30
+ sympy.cosh: "jnp.cosh",
31
+ sympy.acosh: "jnp.acosh",
32
+ sympy.sinh: "jnp.sinh",
33
+ sympy.asinh: "jnp.asinh",
34
+ sympy.tanh: "jnp.tanh",
35
+ sympy.atanh: "jnp.atanh",
36
+ sympy.Pow: "jnp.power",
37
+ sympy.re: "jnp.real",
38
+ sympy.im: "jnp.imag",
39
+ sympy.arg: "jnp.angle",
40
+ # Note: May raise error for ints and complexes
41
+ sympy.erf: "jsp.erf",
42
+ sympy.erfc: "jsp.erfc",
43
+ sympy.LessThan: "jnp.less",
44
+ sympy.GreaterThan: "jnp.greater",
45
+ sympy.And: "jnp.logical_and",
46
+ sympy.Or: "jnp.logical_or",
47
+ sympy.Not: "jnp.logical_not",
48
+ sympy.Max: "jnp.max",
49
+ sympy.Min: "jnp.min",
50
+ sympy.Mod: "jnp.mod",
51
+ }
52
+
 
53
 
54
  def sympy2jaxtext(expr, parameters, symbols_in):
55
  if issubclass(expr.func, sympy.Float):
 
69
  else:
70
  return f'{_func}({", ".join(args)})'
71
 
72
+
73
+ jax_initialized = False
74
+ jax = None
75
+ jnp = None
76
+ jsp = None
77
+
78
+ def _initialize_jax():
79
+ global jax_initialized
80
+ global jax
81
+ global jnp
82
+ global jsp
83
+
84
+ if not jax_initialized:
85
+ import jax as _jax
86
+ from jax import numpy as _jnp
87
+ from jax.scipy import special as _jsp
88
+ jax = _jax
89
+ jnp = _jnp
90
+ jsp = _jsp
91
+
92
+
93
+ def sympy2jax(expression, symbols_in):
94
  """Returns a function f and its parameters;
95
  the function takes an input matrix, and a list of arguments:
96
  f(X, parameters)
 
161
  # 3.5427954 , -2.7479894 ], dtype=float32)
162
  ```
163
  """
164
+ _initialize_jax()
165
+ global jax_initialized
166
+ global jax
167
+ global jnp
168
+ global jsp
169
+
170
  parameters = []
171
+ functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
172
+ hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
173
  text = f"def {hash_string}(X, parameters):\n"
174
  text += " return "
175
  text += functional_form_text
pysr/export_torch.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #####
2
+ # From https://github.com/patrick-kidger/sympytorch
3
+ # Copied here to allow PySR-specific tweaks
4
+ #####
5
+
6
+ import collections as co
7
+ import functools as ft
8
+ import sympy
9
+
10
+ def _reduce(fn):
11
+ def fn_(*args):
12
+ return ft.reduce(fn, args)
13
+ return fn_
14
+
15
+ torch_initialized = False
16
+ torch = None
17
+ _global_func_lookup = None
18
+ _Node = None
19
+ SingleSymPyModule = None
20
+
21
+ def _initialize_torch():
22
+ global torch_initialized
23
+ global torch
24
+ global _global_func_lookup
25
+ global _Node
26
+ global SingleSymPyModule
27
+
28
+ # Way to lazy load torch, only if this is called,
29
+ # but still allow this module to be loaded in __init__
30
+ if not torch_initialized:
31
+ import torch as _torch
32
+ torch = _torch
33
+
34
+ _global_func_lookup = {
35
+ sympy.Mul: _reduce(torch.mul),
36
+ sympy.Add: _reduce(torch.add),
37
+ sympy.div: torch.div,
38
+ sympy.Abs: torch.abs,
39
+ sympy.sign: torch.sign,
40
+ # Note: May raise error for ints.
41
+ sympy.ceiling: torch.ceil,
42
+ sympy.floor: torch.floor,
43
+ sympy.log: torch.log,
44
+ sympy.exp: torch.exp,
45
+ sympy.sqrt: torch.sqrt,
46
+ sympy.cos: torch.cos,
47
+ sympy.acos: torch.acos,
48
+ sympy.sin: torch.sin,
49
+ sympy.asin: torch.asin,
50
+ sympy.tan: torch.tan,
51
+ sympy.atan: torch.atan,
52
+ sympy.atan2: torch.atan2,
53
+ # Note: May give NaN for complex results.
54
+ sympy.cosh: torch.cosh,
55
+ sympy.acosh: torch.acosh,
56
+ sympy.sinh: torch.sinh,
57
+ sympy.asinh: torch.asinh,
58
+ sympy.tanh: torch.tanh,
59
+ sympy.atanh: torch.atanh,
60
+ sympy.Pow: torch.pow,
61
+ sympy.re: torch.real,
62
+ sympy.im: torch.imag,
63
+ sympy.arg: torch.angle,
64
+ # Note: May raise error for ints and complexes
65
+ sympy.erf: torch.erf,
66
+ sympy.loggamma: torch.lgamma,
67
+ sympy.Eq: torch.eq,
68
+ sympy.Ne: torch.ne,
69
+ sympy.StrictGreaterThan: torch.gt,
70
+ sympy.StrictLessThan: torch.lt,
71
+ sympy.LessThan: torch.le,
72
+ sympy.GreaterThan: torch.ge,
73
+ sympy.And: torch.logical_and,
74
+ sympy.Or: torch.logical_or,
75
+ sympy.Not: torch.logical_not,
76
+ sympy.Max: torch.max,
77
+ sympy.Min: torch.min,
78
+ # Matrices
79
+ sympy.MatAdd: torch.add,
80
+ sympy.HadamardProduct: torch.mul,
81
+ sympy.Trace: torch.trace,
82
+ # Note: May raise error for integer matrices.
83
+ sympy.Determinant: torch.det,
84
+ }
85
+
86
+ class _Node(torch.nn.Module):
87
+ """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
88
+ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
89
+ super().__init__(**kwargs)
90
+
91
+ self._sympy_func = expr.func
92
+
93
+ if issubclass(expr.func, sympy.Float):
94
+ self._value = torch.nn.Parameter(torch.tensor(float(expr)))
95
+ self._torch_func = lambda: self._value
96
+ self._args = ()
97
+ elif issubclass(expr.func, sympy.UnevaluatedExpr):
98
+ if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
99
+ raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
100
+ self.register_buffer('_value', torch.tensor(float(expr.args[0])))
101
+ self._torch_func = lambda: self._value
102
+ self._args = ()
103
+ elif issubclass(expr.func, sympy.Integer):
104
+ # Can get here if expr is one of the Integer special cases,
105
+ # e.g. NegativeOne
106
+ self._value = int(expr)
107
+ self._torch_func = lambda: self._value
108
+ self._args = ()
109
+ elif issubclass(expr.func, sympy.Symbol):
110
+ self._name = expr.name
111
+ self._torch_func = lambda value: value
112
+ self._args = ((lambda memodict: memodict[expr.name]),)
113
+ else:
114
+ self._torch_func = _func_lookup[expr.func]
115
+ args = []
116
+ for arg in expr.args:
117
+ try:
118
+ arg_ = _memodict[arg]
119
+ except KeyError:
120
+ arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
121
+ _memodict[arg] = arg_
122
+ args.append(arg_)
123
+ self._args = torch.nn.ModuleList(args)
124
+
125
+ def forward(self, memodict):
126
+ args = []
127
+ for arg in self._args:
128
+ try:
129
+ arg_ = memodict[arg]
130
+ except KeyError:
131
+ arg_ = arg(memodict)
132
+ memodict[arg] = arg_
133
+ args.append(arg_)
134
+ return self._torch_func(*args)
135
+
136
+
137
+ class SingleSymPyModule(torch.nn.Module):
138
+ """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
139
+ def __init__(self, expression, symbols_in,
140
+ extra_funcs=None, **kwargs):
141
+ super().__init__(**kwargs)
142
+
143
+ if extra_funcs is None:
144
+ extra_funcs = {}
145
+ _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
146
+
147
+ _memodict = {}
148
+ self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
149
+ self._expression_string = str(expression)
150
+ self.symbols_in = [str(symbol) for symbol in symbols_in]
151
+
152
+ def __repr__(self):
153
+ return f"{type(self).__name__}(expression={self._expression_string})"
154
+
155
+ def forward(self, X):
156
+ symbols = {symbol: X[:, i]
157
+ for i, symbol in enumerate(self.symbols_in)}
158
+ return self._node(symbols)
159
+
160
+
161
+ def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
162
+ """Returns a module for a given sympy expression with trainable parameters;
163
+
164
+ This function will assume the input to the module is a matrix X, where
165
+ each column corresponds to each symbol you pass in `symbols_in`.
166
+ """
167
+ global SingleSymPyModule
168
+
169
+ _initialize_torch()
170
+
171
+ return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings)
pysr/sr.py CHANGED
@@ -13,7 +13,6 @@ import shutil
13
  from pathlib import Path
14
  from datetime import datetime
15
  import warnings
16
- from .export import sympy2jax
17
 
18
  global_equation_file = 'hall_of_fame.csv'
19
  global_n_features = None
@@ -103,6 +102,8 @@ def pysr(X, y, weights=None,
103
  perturbationFactor=1.0,
104
  timeout=None,
105
  extra_sympy_mappings=None,
 
 
106
  equation_file=None,
107
  verbosity=1e9,
108
  progress=True,
@@ -124,11 +125,12 @@ def pysr(X, y, weights=None,
124
  update=True,
125
  temp_equation_file=False,
126
  output_jax_format=False,
 
127
  optimizer_algorithm="BFGS",
128
  optimizer_nrestarts=3,
129
  optimize_probability=1.0,
130
- optimizer_iterations=10,
131
- ):
132
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
133
  Note: most default parameters have been tuned over several example
134
  equations, but you should adjust `niterations`,
@@ -241,6 +243,8 @@ def pysr(X, y, weights=None,
241
  delete_tempfiles argument.
242
  :param output_jax_format: Whether to create a 'jax_format' column in the output,
243
  containing jax-callable functions and the default parameters in a jax array.
 
 
244
  :returns: pd.DataFrame or list, Results dataframe,
245
  giving complexity, MSE, and equations (as strings), as well as functional
246
  forms. If list, each element corresponds to a dataframe of equations
@@ -334,8 +338,11 @@ def pysr(X, y, weights=None,
334
  weightSimplify=weightSimplify,
335
  constraints=constraints,
336
  extra_sympy_mappings=extra_sympy_mappings,
 
 
337
  julia_project=julia_project, loss=loss,
338
  output_jax_format=output_jax_format,
 
339
  multioutput=multioutput, nout=nout)
340
 
341
  kwargs = {**_set_paths(tempdir), **kwargs}
@@ -715,10 +722,10 @@ def run_feature_selection(X, y, select_k_features):
715
  the k most important features in X, returning indices for those
716
  features as output."""
717
 
718
- from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
719
  from sklearn.feature_selection import SelectFromModel, SelectKBest
720
 
721
- clf = GradientBoostingRegressor(n_estimators=100, learning_rate=0.1, max_depth=1, random_state=0, loss='ls') #RandomForestRegressor()
722
  clf.fit(X, y)
723
  selector = SelectFromModel(clf, threshold=-np.inf,
724
  max_features=select_k_features, prefit=True)
@@ -726,6 +733,8 @@ def run_feature_selection(X, y, select_k_features):
726
 
727
  def get_hof(equation_file=None, n_features=None, variable_names=None,
728
  extra_sympy_mappings=None, output_jax_format=False,
 
 
729
  multioutput=None, nout=None, **kwargs):
730
  """Get the equations from a hall of fame file. If no arguments
731
  entered, the ones used previously from a call to PySR will be used."""
@@ -770,6 +779,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
770
  lambda_format = []
771
  if output_jax_format:
772
  jax_format = []
 
 
773
  use_custom_variable_names = (len(variable_names) != 0)
774
  local_sympy_mappings = {
775
  **extra_sympy_mappings,
@@ -784,11 +795,22 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
784
  for i in range(len(output)):
785
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
786
  sympy_format.append(eqn)
 
 
 
 
 
787
  if output_jax_format:
 
788
  func, params = sympy2jax(eqn, sympy_symbols)
789
  jax_format.append({'callable': func, 'parameters': params})
790
 
791
- lambda_format.append(CallableEquation(sympy_symbols, eqn))
 
 
 
 
 
792
  curMSE = output.loc[i, 'MSE']
793
  curComplexity = output.loc[i, 'Complexity']
794
 
@@ -808,6 +830,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
808
  if output_jax_format:
809
  output_cols += ['jax_format']
810
  output['jax_format'] = jax_format
 
 
 
811
 
812
  ret_outputs.append(output[output_cols])
813
 
 
13
  from pathlib import Path
14
  from datetime import datetime
15
  import warnings
 
16
 
17
  global_equation_file = 'hall_of_fame.csv'
18
  global_n_features = None
 
102
  perturbationFactor=1.0,
103
  timeout=None,
104
  extra_sympy_mappings=None,
105
+ extra_torch_mappings=None,
106
+ extra_jax_mappings=None,
107
  equation_file=None,
108
  verbosity=1e9,
109
  progress=True,
 
125
  update=True,
126
  temp_equation_file=False,
127
  output_jax_format=False,
128
+ output_torch_format=False,
129
  optimizer_algorithm="BFGS",
130
  optimizer_nrestarts=3,
131
  optimize_probability=1.0,
132
+ optimizer_iterations=10
133
+ ):
134
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
135
  Note: most default parameters have been tuned over several example
136
  equations, but you should adjust `niterations`,
 
243
  delete_tempfiles argument.
244
  :param output_jax_format: Whether to create a 'jax_format' column in the output,
245
  containing jax-callable functions and the default parameters in a jax array.
246
+ :param output_torch_format: Whether to create a 'torch_format' column in the output,
247
+ containing a torch module with trainable parameters.
248
  :returns: pd.DataFrame or list, Results dataframe,
249
  giving complexity, MSE, and equations (as strings), as well as functional
250
  forms. If list, each element corresponds to a dataframe of equations
 
338
  weightSimplify=weightSimplify,
339
  constraints=constraints,
340
  extra_sympy_mappings=extra_sympy_mappings,
341
+ extra_jax_mappings=extra_jax_mappings,
342
+ extra_torch_mappings=extra_torch_mappings,
343
  julia_project=julia_project, loss=loss,
344
  output_jax_format=output_jax_format,
345
+ output_torch_format=output_torch_format,
346
  multioutput=multioutput, nout=nout)
347
 
348
  kwargs = {**_set_paths(tempdir), **kwargs}
 
722
  the k most important features in X, returning indices for those
723
  features as output."""
724
 
725
+ from sklearn.ensemble import RandomForestRegressor
726
  from sklearn.feature_selection import SelectFromModel, SelectKBest
727
 
728
+ clf = RandomForestRegressor(n_estimators=100, max_depth=3, random_state=0)
729
  clf.fit(X, y)
730
  selector = SelectFromModel(clf, threshold=-np.inf,
731
  max_features=select_k_features, prefit=True)
 
733
 
734
  def get_hof(equation_file=None, n_features=None, variable_names=None,
735
  extra_sympy_mappings=None, output_jax_format=False,
736
+ output_torch_format=False,
737
+ extra_jax_mappings=None, extra_torch_mappings=None,
738
  multioutput=None, nout=None, **kwargs):
739
  """Get the equations from a hall of fame file. If no arguments
740
  entered, the ones used previously from a call to PySR will be used."""
 
779
  lambda_format = []
780
  if output_jax_format:
781
  jax_format = []
782
+ if output_torch_format:
783
+ torch_format = []
784
  use_custom_variable_names = (len(variable_names) != 0)
785
  local_sympy_mappings = {
786
  **extra_sympy_mappings,
 
795
  for i in range(len(output)):
796
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
797
  sympy_format.append(eqn)
798
+
799
+ # Numpy:
800
+ lambda_format.append(CallableEquation(sympy_symbols, eqn))
801
+
802
+ # JAX:
803
  if output_jax_format:
804
+ from .export_jax import sympy2jax
805
  func, params = sympy2jax(eqn, sympy_symbols)
806
  jax_format.append({'callable': func, 'parameters': params})
807
 
808
+ # Torch:
809
+ if output_torch_format:
810
+ from .export_torch import sympy2torch
811
+ module = sympy2torch(eqn, sympy_symbols)
812
+ torch_format.append(module)
813
+
814
  curMSE = output.loc[i, 'MSE']
815
  curComplexity = output.loc[i, 'Complexity']
816
 
 
830
  if output_jax_format:
831
  output_cols += ['jax_format']
832
  output['jax_format'] = jax_format
833
+ if output_torch_format:
834
+ output_cols += ['torch_format']
835
+ output['torch_format'] = torch_format
836
 
837
  ret_outputs.append(output[output_cols])
838
 
test/test_jax.py CHANGED
@@ -1,6 +1,7 @@
1
  import unittest
2
  import numpy as np
3
- from pysr import pysr, sympy2jax
 
4
  from jax import numpy as jnp
5
  from jax import random
6
  from jax import grad
@@ -15,3 +16,24 @@ class TestJAX(unittest.TestCase):
15
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
16
  f, params = sympy2jax(cosx, [x, y, z])
17
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import unittest
2
  import numpy as np
3
+ from pysr import sympy2jax, get_hof
4
+ import pandas as pd
5
  from jax import numpy as jnp
6
  from jax import random
7
  from jax import grad
 
16
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
17
  f, params = sympy2jax(cosx, [x, y, z])
18
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
19
+ def test_pipeline(self):
20
+ X = np.random.randn(100, 2)
21
+ equations = pd.DataFrame({
22
+ 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
23
+ 'MSE': [1.0, 0.1, 1e-5],
24
+ 'Complexity': [1, 2, 3]
25
+ })
26
+
27
+ equations['Complexity MSE Equation'.split(' ')].to_csv(
28
+ 'equation_file.csv.bkup', sep='|')
29
+
30
+ equations = get_hof(
31
+ 'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
32
+ extra_sympy_mappings={}, output_jax_format=True,
33
+ multioutput=False, nout=1)
34
+
35
+ jformat = equations.iloc[-1].jax_format
36
+ np.testing.assert_almost_equal(
37
+ np.array(jformat['callable'](jnp.array(X), jformat['parameters'])),
38
+ np.square(np.cos(X[:, 0]))
39
+ )
test/test_torch.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import numpy as np
3
+ import pandas as pd
4
+ from pysr import sympy2torch, get_hof
5
+ import torch
6
+ import sympy
7
+
8
+ class TestTorch(unittest.TestCase):
9
+ def test_sympy2torch(self):
10
+ x, y, z = sympy.symbols('x y z')
11
+ cosx = 1.0 * sympy.cos(x) + y
12
+ X = torch.randn((1000, 3))
13
+ true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
14
+ torch_module = sympy2torch(cosx, [x, y, z])
15
+ self.assertTrue(
16
+ np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
17
+ )
18
+ def test_pipeline(self):
19
+ X = np.random.randn(100, 2)
20
+ equations = pd.DataFrame({
21
+ 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
22
+ 'MSE': [1.0, 0.1, 1e-5],
23
+ 'Complexity': [1, 2, 3]
24
+ })
25
+
26
+ equations['Complexity MSE Equation'.split(' ')].to_csv(
27
+ 'equation_file.csv.bkup', sep='|')
28
+
29
+ equations = get_hof(
30
+ 'equation_file.csv', n_features=2, variables_names='x0 x1'.split(' '),
31
+ extra_sympy_mappings={}, output_torch_format=True,
32
+ multioutput=False, nout=1)
33
+
34
+ tformat = equations.iloc[-1].torch_format
35
+ np.testing.assert_almost_equal(
36
+ tformat(torch.tensor(X)).detach().numpy(),
37
+ np.square(np.cos(X[:, 0]))
38
+ )