MilesCranmer commited on
Commit
5a01e6f
1 Parent(s): 780b3a0

Move numpy export code to separate file

Browse files
Files changed (2) hide show
  1. pysr/export_numpy.py +29 -0
  2. pysr/sr.py +3 -27
pysr/export_numpy.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Code for exporting discovered expressions to numpy"""
2
+ import numpy as np
3
+ import pandas as pd
4
+ from sympy import lambdify
5
+
6
+
7
+ class CallableEquation:
8
+ """Simple wrapper for numpy lambda functions built with sympy"""
9
+
10
+ def __init__(self, sympy_symbols, eqn, selection=None, variable_names=None):
11
+ self._sympy = eqn
12
+ self._sympy_symbols = sympy_symbols
13
+ self._selection = selection
14
+ self._variable_names = variable_names
15
+ self._lambda = lambdify(sympy_symbols, eqn)
16
+
17
+ def __repr__(self):
18
+ return f"PySRFunction(X=>{self._sympy})"
19
+
20
+ def __call__(self, X):
21
+ expected_shape = (X.shape[0],)
22
+ if isinstance(X, pd.DataFrame):
23
+ # Lambda function takes as argument:
24
+ return self._lambda(
25
+ **{k: X[k].values for k in self._variable_names}
26
+ ) * np.ones(expected_shape)
27
+ if self._selection is not None:
28
+ X = X[:, self._selection]
29
+ return self._lambda(*X.T) * np.ones(expected_shape)
pysr/sr.py CHANGED
@@ -3,7 +3,7 @@ import sys
3
  import numpy as np
4
  import pandas as pd
5
  import sympy
6
- from sympy import sympify, lambdify
7
  import re
8
  import tempfile
9
  import shutil
@@ -22,6 +22,7 @@ from .julia_helpers import (
22
  _add_sr_to_julia_project,
23
  import_error_string,
24
  )
 
25
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
26
 
27
 
@@ -169,35 +170,10 @@ def best_callable(*args, **kwargs): # pragma: no cover
169
  )
170
 
171
 
172
- class CallableEquation:
173
- """Simple wrapper for numpy lambda functions built with sympy"""
174
-
175
- def __init__(self, sympy_symbols, eqn, selection=None, variable_names=None):
176
- self._sympy = eqn
177
- self._sympy_symbols = sympy_symbols
178
- self._selection = selection
179
- self._variable_names = variable_names
180
- self._lambda = lambdify(sympy_symbols, eqn)
181
-
182
- def __repr__(self):
183
- return f"PySRFunction(X=>{self._sympy})"
184
-
185
- def __call__(self, X):
186
- expected_shape = (X.shape[0],)
187
- if isinstance(X, pd.DataFrame):
188
- # Lambda function takes as argument:
189
- return self._lambda(
190
- **{k: X[k].values for k in self._variable_names}
191
- ) * np.ones(expected_shape)
192
- if self._selection is not None:
193
- X = X[:, self._selection]
194
- return self._lambda(*X.T) * np.ones(expected_shape)
195
-
196
-
197
  class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
198
  """
199
  High-performance symbolic regression.
200
-
201
  This is the scikit-learn interface for SymbolicRegression.jl.
202
  This model will automatically search for equations which fit
203
  a given dataset subject to a particular loss and set of
 
3
  import numpy as np
4
  import pandas as pd
5
  import sympy
6
+ from sympy import sympify
7
  import re
8
  import tempfile
9
  import shutil
 
22
  _add_sr_to_julia_project,
23
  import_error_string,
24
  )
25
+ from .export_numpy import CallableEquation
26
  from .deprecated import make_deprecated_kwargs_for_pysr_regressor
27
 
28
 
 
170
  )
171
 
172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
174
  """
175
  High-performance symbolic regression.
176
+
177
  This is the scikit-learn interface for SymbolicRegression.jl.
178
  This model will automatically search for equations which fit
179
  a given dataset subject to a particular loss and set of