MilesCranmer commited on
Commit
66b15fc
·
1 Parent(s): 205d866

Move everything into single PySRRegressor

Browse files
Files changed (2) hide show
  1. pysr/sklearn.py +0 -138
  2. pysr/sr.py +858 -757
pysr/sklearn.py DELETED
@@ -1,138 +0,0 @@
1
- from pysr import pysr, best_row, get_hof
2
- from sklearn.base import BaseEstimator, RegressorMixin
3
- import inspect
4
- import pandas as pd
5
-
6
-
7
- class PySRRegressor(BaseEstimator, RegressorMixin):
8
- def __init__(self, model_selection="accuracy", **params):
9
- """Initialize settings for pysr.pysr call.
10
-
11
- :param model_selection: How to select a model. Can be 'accuracy' or 'best'. 'best' will optimize a combination of complexity and accuracy.
12
- :type model_selection: str
13
- """
14
- super().__init__()
15
- self.model_selection = model_selection
16
- self.params = params
17
-
18
- # Stored equations:
19
- self.equations = None
20
-
21
- def __repr__(self):
22
- if self.equations is None:
23
- return "PySRRegressor.equations = None"
24
-
25
- equations = self.equations
26
- selected = ["" for _ in range(len(equations))]
27
- if self.model_selection == "accuracy":
28
- chosen_row = -1
29
- elif self.model_selection == "best":
30
- chosen_row = equations["score"].idxmax()
31
- else:
32
- raise NotImplementedError
33
- selected[chosen_row] = ">>>>"
34
- output = "PySRRegressor.equations = [\n"
35
- repr_equations = pd.DataFrame(
36
- dict(
37
- pick=selected,
38
- score=equations["score"],
39
- Equation=equations["Equation"],
40
- MSE=equations["MSE"],
41
- Complexity=equations["Complexity"],
42
- )
43
- )
44
- output += repr_equations.__repr__()
45
- output += "\n]"
46
- return output
47
-
48
- def set_params(self, **params):
49
- """Set parameters for pysr.pysr call or model_selection strategy."""
50
- for key, value in params.items():
51
- if key == "model_selection":
52
- self.model_selection = value
53
- self.params[key] = value
54
-
55
- return self
56
-
57
- def get_params(self, deep=True):
58
- del deep
59
- return {**self.params, "model_selection": self.model_selection}
60
-
61
- def get_best(self):
62
- if self.equations is None:
63
- return 0.0
64
- if self.model_selection == "accuracy":
65
- return self.equations.iloc[-1]
66
- elif self.model_selection == "best":
67
- return best_row(self.equations)
68
- else:
69
- raise NotImplementedError
70
-
71
- def fit(self, X, y, weights=None, variable_names=None):
72
- """Search for equations to fit the dataset.
73
-
74
- :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
75
- :type X: np.ndarray/pandas.DataFrame
76
- :param y: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs). Putting in a 2D array will trigger a search for equations for each feature of y.
77
- :type y: np.ndarray
78
- :param weights: Optional. Same shape as y. Each element is how to weight the mean-square-error loss for that particular element of y.
79
- :type weights: np.ndarray
80
- :param variable_names: a list of names for the variables, other than "x0", "x1", etc.
81
- :type variable_names: list
82
- """
83
- if variable_names is None:
84
- if "variable_names" in self.params:
85
- variable_names = self.params["variable_names"]
86
-
87
- self.equations = pysr(
88
- X=X,
89
- y=y,
90
- weights=weights,
91
- variable_names=variable_names,
92
- **{k: v for k, v in self.params.items() if k != "variable_names"},
93
- )
94
- return self
95
-
96
- def refresh(self):
97
- # Updates self.equations with any new options passed,
98
- # such as extra_sympy_mappings.
99
- self.equations = get_hof(**self.params)
100
-
101
- def predict(self, X):
102
- self.refresh()
103
- np_format = self.get_best()["lambda_format"]
104
- return np_format(X)
105
-
106
- def sympy(self):
107
- self.refresh()
108
- return self.get_best()["sympy_format"]
109
-
110
- def latex(self):
111
- self.refresh()
112
- return self.sympy().simplify()
113
-
114
- def jax(self):
115
- self.set_params(output_jax_format=True)
116
- self.refresh()
117
- return self.get_best()["jax_format"]
118
-
119
- def pytorch(self):
120
- self.set_params(output_torch_format=True)
121
- self.refresh()
122
- return self.get_best()["torch_format"]
123
-
124
-
125
- # Add the docs from pysr() to PySRRegressor():
126
- _pysr_docstring_split = []
127
- _start_recording = False
128
- for line in inspect.getdoc(pysr).split("\n"):
129
- # Skip docs on "X" and "y"
130
- if ":param binary_operators:" in line:
131
- _start_recording = True
132
- if ":returns:" in line:
133
- _start_recording = False
134
- if _start_recording:
135
- _pysr_docstring_split.append(line)
136
- _pysr_docstring = "\n\t".join(_pysr_docstring_split)
137
-
138
- PySRRegressor.__init__.__doc__ += _pysr_docstring
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
pysr/sr.py CHANGED
@@ -11,6 +11,7 @@ from pathlib import Path
11
  from datetime import datetime
12
  import warnings
13
  from multiprocessing import cpu_count
 
14
 
15
  is_julia_warning_silenced = False
16
 
@@ -36,20 +37,6 @@ def install(julia_project=None): # pragma: no cover
36
 
37
 
38
  Main = None
39
- global_state = dict(
40
- equation_file="hall_of_fame.csv",
41
- n_features=None,
42
- variable_names=[],
43
- extra_sympy_mappings={},
44
- extra_torch_mappings={},
45
- extra_jax_mappings={},
46
- output_jax_format=False,
47
- output_torch_format=False,
48
- multioutput=False,
49
- nout=1,
50
- selection=None,
51
- raw_julia_output=None,
52
- )
53
 
54
  already_ran = False
55
 
@@ -93,541 +80,14 @@ sympy_mappings = {
93
  }
94
 
95
 
96
- def pysr(
97
- X,
98
- y,
99
- weights=None,
100
- binary_operators=None,
101
- unary_operators=None,
102
- procs=cpu_count(),
103
- loss="L2DistLoss()",
104
- populations=20,
105
- niterations=100,
106
- ncyclesperiteration=300,
107
- alpha=0.1,
108
- annealing=False,
109
- fractionReplaced=0.10,
110
- fractionReplacedHof=0.10,
111
- npop=1000,
112
- parsimony=1e-4,
113
- migration=True,
114
- hofMigration=True,
115
- shouldOptimizeConstants=True,
116
- topn=10,
117
- weightAddNode=1,
118
- weightInsertNode=3,
119
- weightDeleteNode=3,
120
- weightDoNothing=1,
121
- weightMutateConstant=10,
122
- weightMutateOperator=1,
123
- weightRandomize=1,
124
- weightSimplify=0.002,
125
- perturbationFactor=1.0,
126
- extra_sympy_mappings=None,
127
- extra_torch_mappings=None,
128
- extra_jax_mappings=None,
129
- equation_file=None,
130
- verbosity=1e9,
131
- progress=None,
132
- maxsize=20,
133
- fast_cycle=False,
134
- maxdepth=None,
135
- variable_names=None,
136
- batching=False,
137
- batchSize=50,
138
- select_k_features=None,
139
- warmupMaxsizeBy=0.0,
140
- constraints=None,
141
- useFrequency=True,
142
- tempdir=None,
143
- delete_tempfiles=True,
144
- julia_project=None,
145
- update=True,
146
- temp_equation_file=False,
147
- output_jax_format=False,
148
- output_torch_format=False,
149
- optimizer_algorithm="BFGS",
150
- optimizer_nrestarts=3,
151
- optimize_probability=1.0,
152
- optimizer_iterations=10,
153
- tournament_selection_n=10,
154
- tournament_selection_p=1.0,
155
- denoise=False,
156
- Xresampled=None,
157
- precision=32,
158
- multithreading=None,
159
- **kwargs,
160
- ):
161
- """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
162
- Note: most default parameters have been tuned over several example
163
- equations, but you should adjust `niterations`,
164
- `binary_operators`, `unary_operators` to your requirements.
165
- You can view more detailed explanations of the options on the
166
- [options page](https://pysr.readthedocs.io/en/latest/docs/options/) of the documentation.
167
-
168
- :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
169
- :type X: np.ndarray/pandas.DataFrame
170
- :param y: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs). Putting in a 2D array will trigger a search for equations for each feature of y.
171
- :type y: np.ndarray
172
- :param weights: same shape as y. Each element is how to weight the mean-square-error loss for that particular element of y.
173
- :type weights: np.ndarray
174
- :param binary_operators: List of strings giving the binary operators in Julia's Base. Default is ["+", "-", "*", "/",].
175
- :type binary_operators: list
176
- :param unary_operators: Same but for operators taking a single scalar. Default is [].
177
- :type unary_operators: list
178
- :param procs: Number of processes (=number of populations running).
179
- :type procs: int
180
- :param loss: String of Julia code specifying the loss function. Can either be a loss from LossFunctions.jl, or your own loss written as a function. Examples of custom written losses include: `myloss(x, y) = abs(x-y)` for non-weighted, or `myloss(x, y, w) = w*abs(x-y)` for weighted. Among the included losses, these are as follows. Regression: `LPDistLoss{P}()`, `L1DistLoss()`, `L2DistLoss()` (mean square), `LogitDistLoss()`, `HuberLoss(d)`, `L1EpsilonInsLoss(ϵ)`, `L2EpsilonInsLoss(ϵ)`, `PeriodicLoss(c)`, `QuantileLoss(τ)`. Classification: `ZeroOneLoss()`, `PerceptronLoss()`, `L1HingeLoss()`, `SmoothedL1HingeLoss(γ)`, `ModifiedHuberLoss()`, `L2MarginLoss()`, `ExpLoss()`, `SigmoidLoss()`, `DWDMarginLoss(q)`.
181
- :type loss: str
182
- :param populations: Number of populations running.
183
- :type populations: int
184
- :param niterations: Number of iterations of the algorithm to run. The best equations are printed, and migrate between populations, at the end of each.
185
- :type niterations: int
186
- :param ncyclesperiteration: Number of total mutations to run, per 10 samples of the population, per iteration.
187
- :type ncyclesperiteration: int
188
- :param alpha: Initial temperature.
189
- :type alpha: float
190
- :param annealing: Whether to use annealing. You should (and it is default).
191
- :type annealing: bool
192
- :param fractionReplaced: How much of population to replace with migrating equations from other populations.
193
- :type fractionReplaced: float
194
- :param fractionReplacedHof: How much of population to replace with migrating equations from hall of fame.
195
- :type fractionReplacedHof: float
196
- :param npop: Number of individuals in each population
197
- :type npop: int
198
- :param parsimony: Multiplicative factor for how much to punish complexity.
199
- :type parsimony: float
200
- :param migration: Whether to migrate.
201
- :type migration: bool
202
- :param hofMigration: Whether to have the hall of fame migrate.
203
- :type hofMigration: bool
204
- :param shouldOptimizeConstants: Whether to numerically optimize constants (Nelder-Mead/Newton) at the end of each iteration.
205
- :type shouldOptimizeConstants: bool
206
- :param topn: How many top individuals migrate from each population.
207
- :type topn: int
208
- :param perturbationFactor: Constants are perturbed by a max factor of (perturbationFactor*T + 1). Either multiplied by this or divided by this.
209
- :type perturbationFactor: float
210
- :param weightAddNode: Relative likelihood for mutation to add a node
211
- :type weightAddNode: float
212
- :param weightInsertNode: Relative likelihood for mutation to insert a node
213
- :type weightInsertNode: float
214
- :param weightDeleteNode: Relative likelihood for mutation to delete a node
215
- :type weightDeleteNode: float
216
- :param weightDoNothing: Relative likelihood for mutation to leave the individual
217
- :type weightDoNothing: float
218
- :param weightMutateConstant: Relative likelihood for mutation to change the constant slightly in a random direction.
219
- :type weightMutateConstant: float
220
- :param weightMutateOperator: Relative likelihood for mutation to swap an operator.
221
- :type weightMutateOperator: float
222
- :param weightRandomize: Relative likelihood for mutation to completely delete and then randomly generate the equation
223
- :type weightRandomize: float
224
- :param weightSimplify: Relative likelihood for mutation to simplify constant parts by evaluation
225
- :type weightSimplify: float
226
- :param equation_file: Where to save the files (.csv separated by |)
227
- :type equation_file: str
228
- :param verbosity: What verbosity level to use. 0 means minimal print statements.
229
- :type verbosity: int
230
- :param progress: Whether to use a progress bar instead of printing to stdout.
231
- :type progress: bool
232
- :param maxsize: Max size of an equation.
233
- :type maxsize: int
234
- :param maxdepth: Max depth of an equation. You can use both maxsize and maxdepth. maxdepth is by default set to = maxsize, which means that it is redundant.
235
- :type maxdepth: int
236
- :param fast_cycle: (experimental) - batch over population subsamples. This is a slightly different algorithm than regularized evolution, but does cycles 15% faster. May be algorithmically less efficient.
237
- :type fast_cycle: bool
238
- :param variable_names: a list of names for the variables, other than "x0", "x1", etc.
239
- :type variable_names: list
240
- :param batching: whether to compare population members on small batches during evolution. Still uses full dataset for comparing against hall of fame.
241
- :type batching: bool
242
- :param batchSize: the amount of data to use if doing batching.
243
- :type batchSize: int
244
- :param select_k_features: whether to run feature selection in Python using random forests, before passing to the symbolic regression code. None means no feature selection; an int means select that many features.
245
- :type select_k_features: None/int
246
- :param warmupMaxsizeBy: whether to slowly increase max size from a small number up to the maxsize (if greater than 0). If greater than 0, says the fraction of training time at which the current maxsize will reach the user-passed maxsize.
247
- :type warmupMaxsizeBy: float
248
- :param constraints: dictionary of int (unary) or 2-tuples (binary), this enforces maxsize constraints on the individual arguments of operators. E.g., `'pow': (-1, 1)` says that power laws can have any complexity left argument, but only 1 complexity exponent. Use this to force more interpretable solutions.
249
- :type constraints: dict
250
- :param useFrequency: whether to measure the frequency of complexities, and use that instead of parsimony to explore equation space. Will naturally find equations of all complexities.
251
- :type useFrequency: bool
252
- :param tempdir: directory for the temporary files
253
- :type tempdir: str/None
254
- :param delete_tempfiles: whether to delete the temporary files after finishing
255
- :type delete_tempfiles: bool
256
- :param julia_project: a Julia environment location containing a Project.toml (and potentially the source code for SymbolicRegression.jl). Default gives the Python package directory, where a Project.toml file should be present from the install.
257
- :type julia_project: str/None
258
- :param update: Whether to automatically update Julia packages.
259
- :type update: bool
260
- :param temp_equation_file: Whether to put the hall of fame file in the temp directory. Deletion is then controlled with the delete_tempfiles argument.
261
- :type temp_equation_file: bool
262
- :param output_jax_format: Whether to create a 'jax_format' column in the output, containing jax-callable functions and the default parameters in a jax array.
263
- :type output_jax_format: bool
264
- :param output_torch_format: Whether to create a 'torch_format' column in the output, containing a torch module with trainable parameters.
265
- :type output_torch_format: bool
266
- :param tournament_selection_n: Number of expressions to consider in each tournament.
267
- :type tournament_selection_n: int
268
- :param tournament_selection_p: Probability of selecting the best expression in each tournament. The probability will decay as p*(1-p)^n for other expressions, sorted by loss.
269
- :type tournament_selection_p: float
270
- :param denoise: Whether to use a Gaussian Process to denoise the data before inputting to PySR. Can help PySR fit noisy data.
271
- :type denoise: bool
272
- :param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
273
- :type precision: int
274
- :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
275
- :type multithreading: bool
276
- :param **kwargs: Other options passed to SymbolicRegression.Options, for example, if you modify SymbolicRegression.jl to include additional arguments.
277
- :type **kwargs: dict
278
- :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
279
- :type: pd.DataFrame/list
280
- """
281
- global already_ran
282
-
283
- if binary_operators is None:
284
- binary_operators = "+ * - /".split(" ")
285
- if unary_operators is None:
286
- unary_operators = []
287
- if extra_sympy_mappings is None:
288
- extra_sympy_mappings = {}
289
- if variable_names is None:
290
- variable_names = []
291
- if constraints is None:
292
- constraints = {}
293
- if multithreading is None:
294
- # Default is multithreading=True, unless explicitly set,
295
- # or procs is set to 0 (serial mode).
296
- multithreading = procs != 0
297
-
298
- global Main
299
- if Main is None:
300
- if multithreading:
301
- os.environ["JULIA_NUM_THREADS"] = str(procs)
302
-
303
- Main = init_julia()
304
-
305
- buffer_available = "buffer" in sys.stdout.__dir__()
306
-
307
- if progress is not None:
308
- if progress and not buffer_available:
309
- warnings.warn(
310
- "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
311
- )
312
- progress = False
313
- else:
314
- progress = buffer_available
315
-
316
- assert optimizer_algorithm in ["NelderMead", "BFGS"]
317
- assert tournament_selection_n < npop
318
-
319
- if isinstance(X, pd.DataFrame):
320
- variable_names = list(X.columns)
321
- X = np.array(X)
322
-
323
- if len(X.shape) == 1:
324
- X = X[:, None]
325
-
326
- assert not isinstance(y, pd.DataFrame)
327
-
328
- if len(variable_names) == 0:
329
- variable_names = [f"x{i}" for i in range(X.shape[1])]
330
-
331
- if extra_jax_mappings is not None:
332
- for value in extra_jax_mappings.values():
333
- if not isinstance(value, str):
334
- raise NotImplementedError(
335
- "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
336
- )
337
-
338
- if extra_torch_mappings is not None:
339
- for value in extra_jax_mappings.values():
340
- if not callable(value):
341
- raise NotImplementedError(
342
- "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
343
- )
344
-
345
- use_custom_variable_names = len(variable_names) != 0
346
- # TODO: this is always true.
347
-
348
- _check_assertions(
349
- X,
350
- binary_operators,
351
- unary_operators,
352
- use_custom_variable_names,
353
- variable_names,
354
- weights,
355
- y,
356
- )
357
-
358
- if len(X) > 10000 and not batching:
359
- warnings.warn(
360
- "Note: you are running with more than 10,000 datapoints. You should consider turning on batching (https://pysr.readthedocs.io/en/latest/docs/options/#batching). You should also reconsider if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form with symbolic regression. More datapoints will lower the search speed."
361
- )
362
-
363
- if maxsize > 40:
364
- warnings.warn(
365
- "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`."
366
- )
367
- if maxsize < 7:
368
- raise NotImplementedError("PySR requires a maxsize of at least 7")
369
-
370
- X, selection = _handle_feature_selection(X, select_k_features, y, variable_names)
371
-
372
- if maxdepth is None:
373
- maxdepth = maxsize
374
- if isinstance(binary_operators, str):
375
- binary_operators = [binary_operators]
376
- if isinstance(unary_operators, str):
377
- unary_operators = [unary_operators]
378
-
379
- if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
380
- multioutput = False
381
- nout = 1
382
- y = y.reshape(-1)
383
- elif len(y.shape) == 2:
384
- multioutput = True
385
- nout = y.shape[1]
386
- else:
387
- raise NotImplementedError("y shape not supported!")
388
-
389
- if denoise:
390
- if weights is not None:
391
- raise NotImplementedError(
392
- "No weights for denoising - the weights are learned."
393
- )
394
- if Xresampled is not None:
395
- # Select among only the selected features:
396
- if isinstance(Xresampled, pd.DataFrame):
397
- # Handle Xresampled is pandas dataframe
398
- if selection is not None:
399
- Xresampled = Xresampled[[variable_names[i] for i in selection]]
400
- else:
401
- Xresampled = Xresampled[variable_names]
402
- Xresampled = np.array(Xresampled)
403
- else:
404
- if selection is not None:
405
- Xresampled = Xresampled[:, selection]
406
- if multioutput:
407
- y = np.stack(
408
- [_denoise(X, y[:, i], Xresampled=Xresampled)[1] for i in range(nout)],
409
- axis=1,
410
- )
411
- if Xresampled is not None:
412
- X = Xresampled
413
- else:
414
- X, y = _denoise(X, y, Xresampled=Xresampled)
415
-
416
- julia_project = _get_julia_project(julia_project)
417
-
418
- tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
419
-
420
- if temp_equation_file:
421
- equation_file = tmpdir / "hall_of_fame.csv"
422
- elif equation_file is None:
423
- date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
424
- equation_file = "hall_of_fame_" + date_time + ".csv"
425
-
426
- _create_inline_operators(
427
- binary_operators=binary_operators, unary_operators=unary_operators
428
- )
429
- _handle_constraints(
430
- binary_operators=binary_operators,
431
- unary_operators=unary_operators,
432
- constraints=constraints,
433
- )
434
-
435
- una_constraints = [constraints[op] for op in unary_operators]
436
- bin_constraints = [constraints[op] for op in binary_operators]
437
-
438
- try:
439
- # TODO: is this needed since Julia now prints directly to stdout?
440
- term_width = shutil.get_terminal_size().columns
441
- except:
442
- _, term_width = subprocess.check_output(["stty", "size"]).split()
443
-
444
- if not already_ran:
445
- from julia import Pkg
446
-
447
- Pkg.activate(f"{_escape_filename(julia_project)}")
448
- try:
449
- if update:
450
- Pkg.resolve()
451
- Pkg.instantiate()
452
- else:
453
- Pkg.instantiate()
454
- except RuntimeError as e:
455
- raise ImportError(
456
- f"""
457
- Required dependencies are not installed or built. Run the following code in the Python REPL:
458
-
459
- >>> import pysr
460
- >>> pysr.install()
461
-
462
- Tried to activate project {julia_project} but failed."""
463
- ) from e
464
- Main.eval("using SymbolicRegression")
465
-
466
- Main.plus = Main.eval("(+)")
467
- Main.sub = Main.eval("(-)")
468
- Main.mult = Main.eval("(*)")
469
- Main.pow = Main.eval("(^)")
470
- Main.div = Main.eval("(/)")
471
-
472
- Main.custom_loss = Main.eval(loss)
473
-
474
- mutationWeights = [
475
- float(weightMutateConstant),
476
- float(weightMutateOperator),
477
- float(weightAddNode),
478
- float(weightInsertNode),
479
- float(weightDeleteNode),
480
- float(weightSimplify),
481
- float(weightRandomize),
482
- float(weightDoNothing),
483
- ]
484
-
485
- options = Main.Options(
486
- binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
487
- unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
488
- bin_constraints=bin_constraints,
489
- una_constraints=una_constraints,
490
- parsimony=float(parsimony),
491
- loss=Main.custom_loss,
492
- alpha=float(alpha),
493
- maxsize=int(maxsize),
494
- maxdepth=int(maxdepth),
495
- fast_cycle=fast_cycle,
496
- migration=migration,
497
- hofMigration=hofMigration,
498
- fractionReplacedHof=float(fractionReplacedHof),
499
- shouldOptimizeConstants=shouldOptimizeConstants,
500
- hofFile=_escape_filename(equation_file),
501
- npopulations=int(populations),
502
- optimizer_algorithm=optimizer_algorithm,
503
- optimizer_nrestarts=int(optimizer_nrestarts),
504
- optimize_probability=float(optimize_probability),
505
- optimizer_iterations=int(optimizer_iterations),
506
- perturbationFactor=float(perturbationFactor),
507
- annealing=annealing,
508
- batching=batching,
509
- batchSize=int(min([batchSize, len(X)]) if batching else len(X)),
510
- mutationWeights=mutationWeights,
511
- warmupMaxsizeBy=float(warmupMaxsizeBy),
512
- useFrequency=useFrequency,
513
- npop=int(npop),
514
- ns=int(tournament_selection_n),
515
- probPickFirst=float(tournament_selection_p),
516
- ncyclesperiteration=int(ncyclesperiteration),
517
- fractionReplaced=float(fractionReplaced),
518
- topn=int(topn),
519
- verbosity=int(verbosity),
520
- progress=progress,
521
- terminal_width=int(term_width),
522
- **kwargs,
523
- )
524
-
525
- np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[precision]
526
-
527
- Main.X = np.array(X, dtype=np_dtype).T
528
- if len(y.shape) == 1:
529
- Main.y = np.array(y, dtype=np_dtype)
530
- else:
531
- Main.y = np.array(y, dtype=np_dtype).T
532
- if weights is not None:
533
- if len(weights.shape) == 1:
534
- Main.weights = np.array(weights, dtype=np_dtype)
535
- else:
536
- Main.weights = np.array(weights, dtype=np_dtype).T
537
- else:
538
- Main.weights = None
539
-
540
- cprocs = 0 if multithreading else procs
541
-
542
- raw_julia_output = Main.EquationSearch(
543
- Main.X,
544
- Main.y,
545
- weights=Main.weights,
546
- niterations=int(niterations),
547
- varMap=(
548
- variable_names
549
- if selection is None
550
- else [variable_names[i] for i in selection]
551
- ),
552
- options=options,
553
- numprocs=int(cprocs),
554
- multithreading=bool(multithreading),
555
- )
556
-
557
- _set_globals(
558
- X=X,
559
- equation_file=equation_file,
560
- variable_names=variable_names,
561
- extra_sympy_mappings=extra_sympy_mappings,
562
- extra_torch_mappings=extra_torch_mappings,
563
- extra_jax_mappings=extra_jax_mappings,
564
- output_jax_format=output_jax_format,
565
- output_torch_format=output_torch_format,
566
- multioutput=multioutput,
567
- nout=nout,
568
- selection=selection,
569
- raw_julia_output=raw_julia_output,
570
- )
571
-
572
- equations = get_hof(
573
- equation_file=equation_file,
574
- n_features=X.shape[1],
575
- variable_names=variable_names,
576
- output_jax_format=output_jax_format,
577
- output_torch_format=output_torch_format,
578
- selection=selection,
579
- extra_sympy_mappings=extra_sympy_mappings,
580
- extra_jax_mappings=extra_jax_mappings,
581
- extra_torch_mappings=extra_torch_mappings,
582
- multioutput=multioutput,
583
- nout=nout,
584
  )
585
-
586
- if delete_tempfiles:
587
- shutil.rmtree(tmpdir)
588
-
589
- already_ran = True
590
-
591
- return equations
592
-
593
-
594
- def _set_globals(
595
- *,
596
- equation_file,
597
- variable_names,
598
- extra_sympy_mappings,
599
- extra_torch_mappings,
600
- extra_jax_mappings,
601
- output_jax_format,
602
- output_torch_format,
603
- multioutput,
604
- nout,
605
- selection,
606
- raw_julia_output,
607
- X=None,
608
- n_features=None
609
- ):
610
- global global_state
611
-
612
- if n_features is None and X is not None:
613
- global_state["n_features"] = X.shape[1]
614
- elif X is None and n_features is not None:
615
- global_state["n_features"] = n_features
616
- elif X is not None and n_features is not None:
617
- assert X.shape[1] == n_features
618
- global_state["n_features"] = n_features
619
-
620
- global_state["equation_file"] = equation_file
621
- global_state["variable_names"] = variable_names
622
- global_state["extra_sympy_mappings"] = extra_sympy_mappings
623
- global_state["extra_torch_mappings"] = extra_torch_mappings
624
- global_state["extra_jax_mappings"] = extra_jax_mappings
625
- global_state["output_jax_format"] = output_jax_format
626
- global_state["output_torch_format"] = output_torch_format
627
- global_state["multioutput"] = multioutput
628
- global_state["nout"] = nout
629
- global_state["selection"] = selection
630
- global_state["raw_julia_output"] = raw_julia_output
631
 
632
 
633
  def _handle_constraints(binary_operators, unary_operators, constraints):
@@ -654,6 +114,7 @@ def _handle_constraints(binary_operators, unary_operators, constraints):
654
 
655
 
656
  def _create_inline_operators(binary_operators, unary_operators):
 
657
  for op_list in [binary_operators, unary_operators]:
658
  for i, op in enumerate(op_list):
659
  is_user_defined_operator = "(" in op
@@ -718,215 +179,6 @@ def run_feature_selection(X, y, select_k_features):
718
  return selector.get_support(indices=True)
719
 
720
 
721
- def get_hof(
722
- equation_file=None,
723
- n_features=None,
724
- variable_names=None,
725
- output_jax_format=None,
726
- output_torch_format=None,
727
- selection=None,
728
- extra_sympy_mappings=None,
729
- extra_jax_mappings=None,
730
- extra_torch_mappings=None,
731
- multioutput=None,
732
- nout=None,
733
- **kwargs,
734
- ):
735
- """Get the equations from a hall of fame file. If no arguments
736
- entered, the ones used previously from a call to PySR will be used."""
737
-
738
- global global_state
739
-
740
- if equation_file is None:
741
- equation_file = global_state["equation_file"]
742
- if n_features is None:
743
- n_features = global_state["n_features"]
744
- if variable_names is None:
745
- variable_names = global_state["variable_names"]
746
- if extra_sympy_mappings is None:
747
- extra_sympy_mappings = global_state["extra_sympy_mappings"]
748
- if extra_jax_mappings is None:
749
- extra_jax_mappings = global_state["extra_jax_mappings"]
750
- if extra_torch_mappings is None:
751
- extra_torch_mappings = global_state["extra_torch_mappings"]
752
- if output_torch_format is None:
753
- output_torch_format = global_state["output_torch_format"]
754
- if output_jax_format is None:
755
- output_jax_format = global_state["output_jax_format"]
756
- if multioutput is None:
757
- multioutput = global_state["multioutput"]
758
- if nout is None:
759
- nout = global_state["nout"]
760
- if selection is None:
761
- selection = global_state["selection"]
762
-
763
- global_state["selection"] = selection
764
- global_state["equation_file"] = equation_file
765
- global_state["n_features"] = n_features
766
- global_state["variable_names"] = variable_names
767
- global_state["extra_sympy_mappings"] = extra_sympy_mappings
768
- global_state["extra_jax_mappings"] = extra_jax_mappings
769
- global_state["extra_torch_mappings"] = extra_torch_mappings
770
- global_state["output_torch_format"] = output_torch_format
771
- global_state["output_jax_format"] = output_jax_format
772
- global_state["multioutput"] = multioutput
773
- global_state["nout"] = nout
774
- global_state["selection"] = selection
775
-
776
- try:
777
- if multioutput:
778
- all_outputs = [
779
- pd.read_csv(str(equation_file) + f".out{i}" + ".bkup", sep="|")
780
- for i in range(1, nout + 1)
781
- ]
782
- else:
783
- all_outputs = [pd.read_csv(str(equation_file) + ".bkup", sep="|")]
784
- except FileNotFoundError:
785
- raise RuntimeError(
786
- "Couldn't find equation file! The equation search likely exited before a single iteration completed."
787
- )
788
-
789
- ret_outputs = []
790
-
791
- for output in all_outputs:
792
-
793
- scores = []
794
- lastMSE = None
795
- lastComplexity = 0
796
- sympy_format = []
797
- lambda_format = []
798
- if output_jax_format:
799
- jax_format = []
800
- if output_torch_format:
801
- torch_format = []
802
- use_custom_variable_names = len(variable_names) != 0
803
- local_sympy_mappings = {**extra_sympy_mappings, **sympy_mappings}
804
-
805
- if use_custom_variable_names:
806
- sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
807
- else:
808
- sympy_symbols = [sympy.Symbol("x%d" % i) for i in range(n_features)]
809
-
810
- for _, eqn_row in output.iterrows():
811
- eqn = sympify(eqn_row["Equation"], locals=local_sympy_mappings)
812
- sympy_format.append(eqn)
813
-
814
- # Numpy:
815
- lambda_format.append(
816
- CallableEquation(sympy_symbols, eqn, selection, variable_names)
817
- )
818
-
819
- # JAX:
820
- if output_jax_format:
821
- from .export_jax import sympy2jax
822
-
823
- func, params = sympy2jax(
824
- eqn,
825
- sympy_symbols,
826
- selection=selection,
827
- extra_jax_mappings=extra_jax_mappings,
828
- )
829
- jax_format.append({"callable": func, "parameters": params})
830
-
831
- # Torch:
832
- if output_torch_format:
833
- from .export_torch import sympy2torch
834
-
835
- module = sympy2torch(
836
- eqn,
837
- sympy_symbols,
838
- selection=selection,
839
- extra_torch_mappings=extra_torch_mappings,
840
- )
841
- torch_format.append(module)
842
-
843
- curMSE = eqn_row["MSE"]
844
- curComplexity = eqn_row["Complexity"]
845
-
846
- if lastMSE is None:
847
- cur_score = 0.0
848
- else:
849
- if curMSE > 0.0:
850
- cur_score = -np.log(curMSE / lastMSE) / (
851
- curComplexity - lastComplexity
852
- )
853
- else:
854
- cur_score = np.inf
855
-
856
- scores.append(cur_score)
857
- lastMSE = curMSE
858
- lastComplexity = curComplexity
859
-
860
- output["score"] = np.array(scores)
861
- output["sympy_format"] = sympy_format
862
- output["lambda_format"] = lambda_format
863
- output_cols = [
864
- "Complexity",
865
- "MSE",
866
- "score",
867
- "Equation",
868
- "sympy_format",
869
- "lambda_format",
870
- ]
871
- if output_jax_format:
872
- output_cols += ["jax_format"]
873
- output["jax_format"] = jax_format
874
- if output_torch_format:
875
- output_cols += ["torch_format"]
876
- output["torch_format"] = torch_format
877
-
878
- ret_outputs.append(output[output_cols])
879
-
880
- if multioutput:
881
- return ret_outputs
882
- return ret_outputs[0]
883
-
884
-
885
- def best_row(equations=None):
886
- """Return the best row of a hall of fame file using the score column.
887
- By default this uses the last equation file.
888
- """
889
- if equations is None:
890
- equations = get_hof()
891
- if isinstance(equations, list):
892
- return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
893
- return equations.iloc[np.argmax(equations["score"])]
894
-
895
-
896
- def best_tex(equations=None):
897
- """Return the equation with the best score, in latex format
898
- By default this uses the last equation file.
899
- """
900
- if equations is None:
901
- equations = get_hof()
902
- if isinstance(equations, list):
903
- return [
904
- sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
905
- ]
906
- return sympy.latex(best_row(equations)["sympy_format"].simplify())
907
-
908
-
909
- def best(equations=None):
910
- """Return the equation with the best score, in sympy format.
911
- By default this uses the last equation file.
912
- """
913
- if equations is None:
914
- equations = get_hof()
915
- if isinstance(equations, list):
916
- return [best_row(eq)["sympy_format"].simplify() for eq in equations]
917
- return best_row(equations)["sympy_format"].simplify()
918
-
919
-
920
- def best_callable(equations=None):
921
- """Return the equation with the best score, in callable format.
922
- By default this uses the last equation file.
923
- """
924
- if equations is None:
925
- equations = get_hof()
926
- if isinstance(equations, list):
927
- return [best_row(eq)["lambda_format"] for eq in equations]
928
- return best_row(equations)["lambda_format"]
929
-
930
 
931
  def _escape_filename(filename):
932
  """Turns a file into a string representation with correctly escaped backslashes"""
@@ -934,6 +186,14 @@ def _escape_filename(filename):
934
  str_repr = str_repr.replace("\\", "\\\\")
935
  return str_repr
936
 
 
 
 
 
 
 
 
 
937
 
938
  # https://gist.github.com/garrettdreyfus/8153571
939
  def _yesno(question):
@@ -1061,3 +321,844 @@ julia = "1.5"
1061
 
1062
  project_toml_path = tmp_dir / "Project.toml"
1063
  project_toml_path.write_text(project_toml)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  from datetime import datetime
12
  import warnings
13
  from multiprocessing import cpu_count
14
+ from sklearn.base import BaseEstimator, RegressorMixin
15
 
16
  is_julia_warning_silenced = False
17
 
 
37
 
38
 
39
  Main = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
 
41
  already_ran = False
42
 
 
80
  }
81
 
82
 
83
+ def pysr(X, y, weights=None, **kwargs):
84
+ warnings.warn(
85
+ "Calling `pysr` is deprecated. Please use `model = PySRRegressor(**params); model.fit(X, y)` going forward.",
86
+ DeprecationWarning,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  )
88
+ model = PySRRegressor(**kwargs)
89
+ model.fit(X, y, weights=weights)
90
+ return model.equations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
 
93
  def _handle_constraints(binary_operators, unary_operators, constraints):
 
114
 
115
 
116
  def _create_inline_operators(binary_operators, unary_operators):
117
+ global Main
118
  for op_list in [binary_operators, unary_operators]:
119
  for i, op in enumerate(op_list):
120
  is_user_defined_operator = "(" in op
 
179
  return selector.get_support(indices=True)
180
 
181
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
182
 
183
  def _escape_filename(filename):
184
  """Turns a file into a string representation with correctly escaped backslashes"""
 
186
  str_repr = str_repr.replace("\\", "\\\\")
187
  return str_repr
188
 
189
+ def best(*args, **kwargs):
190
+ raise NotImplementedError("`best` has been deprecated. Please use the `PySRRegressor` interface. After fitting, you can return `.sympy()` to get the sympy representation of the best equation.")
191
+
192
+ def best_tex(*args, **kwargs):
193
+ raise NotImplementedError("`best_tex` has been deprecated. Please use the `PySRRegressor` interface. After fitting, you can return `.latex()` to get the sympy representation of the best equation.")
194
+
195
+ def best_callable(*args, **kwargs):
196
+ raise NotImplementedError("`best_callable` has been deprecated. Please use the `PySRRegressor` interface. After fitting, you can use `.predict(X)` to use the best callable.")
197
 
198
  # https://gist.github.com/garrettdreyfus/8153571
199
  def _yesno(question):
 
321
 
322
  project_toml_path = tmp_dir / "Project.toml"
323
  project_toml_path.write_text(project_toml)
324
+
325
+
326
+ class PySRRegressor(BaseEstimator, RegressorMixin):
327
+ def __init__(
328
+ self,
329
+ model_selection="accuracy",
330
+ weights=None,
331
+ binary_operators=None,
332
+ unary_operators=None,
333
+ procs=cpu_count(),
334
+ loss="L2DistLoss()",
335
+ populations=20,
336
+ niterations=100,
337
+ ncyclesperiteration=300,
338
+ alpha=0.1,
339
+ annealing=False,
340
+ fractionReplaced=0.10,
341
+ fractionReplacedHof=0.10,
342
+ npop=1000,
343
+ parsimony=1e-4,
344
+ migration=True,
345
+ hofMigration=True,
346
+ shouldOptimizeConstants=True,
347
+ topn=10,
348
+ weightAddNode=1,
349
+ weightInsertNode=3,
350
+ weightDeleteNode=3,
351
+ weightDoNothing=1,
352
+ weightMutateConstant=10,
353
+ weightMutateOperator=1,
354
+ weightRandomize=1,
355
+ weightSimplify=0.002,
356
+ perturbationFactor=1.0,
357
+ extra_sympy_mappings=None,
358
+ extra_torch_mappings=None,
359
+ extra_jax_mappings=None,
360
+ equation_file=None,
361
+ verbosity=1e9,
362
+ progress=None,
363
+ maxsize=20,
364
+ fast_cycle=False,
365
+ maxdepth=None,
366
+ variable_names=None,
367
+ batching=False,
368
+ batchSize=50,
369
+ select_k_features=None,
370
+ warmupMaxsizeBy=0.0,
371
+ constraints=None,
372
+ useFrequency=True,
373
+ tempdir=None,
374
+ delete_tempfiles=True,
375
+ julia_project=None,
376
+ update=True,
377
+ temp_equation_file=False,
378
+ output_jax_format=False,
379
+ output_torch_format=False,
380
+ optimizer_algorithm="BFGS",
381
+ optimizer_nrestarts=3,
382
+ optimize_probability=1.0,
383
+ optimizer_iterations=10,
384
+ tournament_selection_n=10,
385
+ tournament_selection_p=1.0,
386
+ denoise=False,
387
+ Xresampled=None,
388
+ precision=32,
389
+ multithreading=None,
390
+ **kwargs,
391
+ ):
392
+ """Initialize settings for an equation search in PySR.
393
+
394
+ Note: most default parameters have been tuned over several example
395
+ equations, but you should adjust `niterations`,
396
+ `binary_operators`, `unary_operators` to your requirements.
397
+ You can view more detailed explanations of the options on the
398
+ [options page](https://pysr.readthedocs.io/en/latest/docs/options/) of the documentation.
399
+
400
+ :param model_selection: How to select a model. Can be 'accuracy' or 'best'. 'best' will optimize a combination of complexity and accuracy.
401
+ :type model_selection: str
402
+ :param binary_operators: List of strings giving the binary operators in Julia's Base. Default is ["+", "-", "*", "/",].
403
+ :type binary_operators: list
404
+ :param unary_operators: Same but for operators taking a single scalar. Default is [].
405
+ :type unary_operators: list
406
+ :param procs: Number of processes (=number of populations running).
407
+ :type procs: int
408
+ :param loss: String of Julia code specifying the loss function. Can either be a loss from LossFunctions.jl, or your own loss written as a function. Examples of custom written losses include: `myloss(x, y) = abs(x-y)` for non-weighted, or `myloss(x, y, w) = w*abs(x-y)` for weighted. Among the included losses, these are as follows. Regression: `LPDistLoss{P}()`, `L1DistLoss()`, `L2DistLoss()` (mean square), `LogitDistLoss()`, `HuberLoss(d)`, `L1EpsilonInsLoss(ϵ)`, `L2EpsilonInsLoss(ϵ)`, `PeriodicLoss(c)`, `QuantileLoss(τ)`. Classification: `ZeroOneLoss()`, `PerceptronLoss()`, `L1HingeLoss()`, `SmoothedL1HingeLoss(γ)`, `ModifiedHuberLoss()`, `L2MarginLoss()`, `ExpLoss()`, `SigmoidLoss()`, `DWDMarginLoss(q)`.
409
+ :type loss: str
410
+ :param populations: Number of populations running.
411
+ :type populations: int
412
+ :param niterations: Number of iterations of the algorithm to run. The best equations are printed, and migrate between populations, at the end of each.
413
+ :type niterations: int
414
+ :param ncyclesperiteration: Number of total mutations to run, per 10 samples of the population, per iteration.
415
+ :type ncyclesperiteration: int
416
+ :param alpha: Initial temperature.
417
+ :type alpha: float
418
+ :param annealing: Whether to use annealing. You should (and it is default).
419
+ :type annealing: bool
420
+ :param fractionReplaced: How much of population to replace with migrating equations from other populations.
421
+ :type fractionReplaced: float
422
+ :param fractionReplacedHof: How much of population to replace with migrating equations from hall of fame.
423
+ :type fractionReplacedHof: float
424
+ :param npop: Number of individuals in each population
425
+ :type npop: int
426
+ :param parsimony: Multiplicative factor for how much to punish complexity.
427
+ :type parsimony: float
428
+ :param migration: Whether to migrate.
429
+ :type migration: bool
430
+ :param hofMigration: Whether to have the hall of fame migrate.
431
+ :type hofMigration: bool
432
+ :param shouldOptimizeConstants: Whether to numerically optimize constants (Nelder-Mead/Newton) at the end of each iteration.
433
+ :type shouldOptimizeConstants: bool
434
+ :param topn: How many top individuals migrate from each population.
435
+ :type topn: int
436
+ :param perturbationFactor: Constants are perturbed by a max factor of (perturbationFactor*T + 1). Either multiplied by this or divided by this.
437
+ :type perturbationFactor: float
438
+ :param weightAddNode: Relative likelihood for mutation to add a node
439
+ :type weightAddNode: float
440
+ :param weightInsertNode: Relative likelihood for mutation to insert a node
441
+ :type weightInsertNode: float
442
+ :param weightDeleteNode: Relative likelihood for mutation to delete a node
443
+ :type weightDeleteNode: float
444
+ :param weightDoNothing: Relative likelihood for mutation to leave the individual
445
+ :type weightDoNothing: float
446
+ :param weightMutateConstant: Relative likelihood for mutation to change the constant slightly in a random direction.
447
+ :type weightMutateConstant: float
448
+ :param weightMutateOperator: Relative likelihood for mutation to swap an operator.
449
+ :type weightMutateOperator: float
450
+ :param weightRandomize: Relative likelihood for mutation to completely delete and then randomly generate the equation
451
+ :type weightRandomize: float
452
+ :param weightSimplify: Relative likelihood for mutation to simplify constant parts by evaluation
453
+ :type weightSimplify: float
454
+ :param equation_file: Where to save the files (.csv separated by |)
455
+ :type equation_file: str
456
+ :param verbosity: What verbosity level to use. 0 means minimal print statements.
457
+ :type verbosity: int
458
+ :param progress: Whether to use a progress bar instead of printing to stdout.
459
+ :type progress: bool
460
+ :param maxsize: Max size of an equation.
461
+ :type maxsize: int
462
+ :param maxdepth: Max depth of an equation. You can use both maxsize and maxdepth. maxdepth is by default set to = maxsize, which means that it is redundant.
463
+ :type maxdepth: int
464
+ :param fast_cycle: (experimental) - batch over population subsamples. This is a slightly different algorithm than regularized evolution, but does cycles 15% faster. May be algorithmically less efficient.
465
+ :type fast_cycle: bool
466
+ :param variable_names: a list of names for the variables, other than "x0", "x1", etc.
467
+ :type variable_names: list
468
+ :param batching: whether to compare population members on small batches during evolution. Still uses full dataset for comparing against hall of fame.
469
+ :type batching: bool
470
+ :param batchSize: the amount of data to use if doing batching.
471
+ :type batchSize: int
472
+ :param select_k_features: whether to run feature selection in Python using random forests, before passing to the symbolic regression code. None means no feature selection; an int means select that many features.
473
+ :type select_k_features: None/int
474
+ :param warmupMaxsizeBy: whether to slowly increase max size from a small number up to the maxsize (if greater than 0). If greater than 0, says the fraction of training time at which the current maxsize will reach the user-passed maxsize.
475
+ :type warmupMaxsizeBy: float
476
+ :param constraints: dictionary of int (unary) or 2-tuples (binary), this enforces maxsize constraints on the individual arguments of operators. E.g., `'pow': (-1, 1)` says that power laws can have any complexity left argument, but only 1 complexity exponent. Use this to force more interpretable solutions.
477
+ :type constraints: dict
478
+ :param useFrequency: whether to measure the frequency of complexities, and use that instead of parsimony to explore equation space. Will naturally find equations of all complexities.
479
+ :type useFrequency: bool
480
+ :param tempdir: directory for the temporary files
481
+ :type tempdir: str/None
482
+ :param delete_tempfiles: whether to delete the temporary files after finishing
483
+ :type delete_tempfiles: bool
484
+ :param julia_project: a Julia environment location containing a Project.toml (and potentially the source code for SymbolicRegression.jl). Default gives the Python package directory, where a Project.toml file should be present from the install.
485
+ :type julia_project: str/None
486
+ :param update: Whether to automatically update Julia packages.
487
+ :type update: bool
488
+ :param temp_equation_file: Whether to put the hall of fame file in the temp directory. Deletion is then controlled with the delete_tempfiles argument.
489
+ :type temp_equation_file: bool
490
+ :param output_jax_format: Whether to create a 'jax_format' column in the output, containing jax-callable functions and the default parameters in a jax array.
491
+ :type output_jax_format: bool
492
+ :param output_torch_format: Whether to create a 'torch_format' column in the output, containing a torch module with trainable parameters.
493
+ :type output_torch_format: bool
494
+ :param tournament_selection_n: Number of expressions to consider in each tournament.
495
+ :type tournament_selection_n: int
496
+ :param tournament_selection_p: Probability of selecting the best expression in each tournament. The probability will decay as p*(1-p)^n for other expressions, sorted by loss.
497
+ :type tournament_selection_p: float
498
+ :param denoise: Whether to use a Gaussian Process to denoise the data before inputting to PySR. Can help PySR fit noisy data.
499
+ :type denoise: bool
500
+ :param precision: What precision to use for the data. By default this is 32 (float32), but you can select 64 or 16 as well.
501
+ :type precision: int
502
+ :param multithreading: Use multithreading instead of distributed backend. Default is yes. Using procs=0 will turn off both.
503
+ :type multithreading: bool
504
+ :param **kwargs: Other options passed to SymbolicRegression.Options, for example, if you modify SymbolicRegression.jl to include additional arguments.
505
+ :type **kwargs: dict
506
+ :returns: Results dataframe, giving complexity, MSE, and equations (as strings), as well as functional forms. If list, each element corresponds to a dataframe of equations for each output.
507
+ :type: pd.DataFrame/list
508
+ """
509
+ super().__init__()
510
+ self.model_selection = model_selection
511
+
512
+ if binary_operators is None:
513
+ binary_operators = "+ * - /".split(" ")
514
+ if unary_operators is None:
515
+ unary_operators = []
516
+ if extra_sympy_mappings is None:
517
+ extra_sympy_mappings = {}
518
+ if variable_names is None:
519
+ variable_names = []
520
+ if constraints is None:
521
+ constraints = {}
522
+ if multithreading is None:
523
+ # Default is multithreading=True, unless explicitly set,
524
+ # or procs is set to 0 (serial mode).
525
+ multithreading = procs != 0
526
+
527
+ buffer_available = "buffer" in sys.stdout.__dir__()
528
+
529
+ if progress is not None:
530
+ if progress and not buffer_available:
531
+ warnings.warn(
532
+ "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
533
+ )
534
+ progress = False
535
+ else:
536
+ progress = buffer_available
537
+
538
+ assert optimizer_algorithm in ["NelderMead", "BFGS"]
539
+ assert tournament_selection_n < npop
540
+
541
+ if extra_jax_mappings is not None:
542
+ for value in extra_jax_mappings.values():
543
+ if not isinstance(value, str):
544
+ raise NotImplementedError(
545
+ "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
546
+ )
547
+
548
+ if extra_torch_mappings is not None:
549
+ for value in extra_jax_mappings.values():
550
+ if not callable(value):
551
+ raise NotImplementedError(
552
+ "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
553
+ )
554
+
555
+ if maxsize > 40:
556
+ warnings.warn(
557
+ "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`."
558
+ )
559
+ elif maxsize < 7:
560
+ raise NotImplementedError("PySR requires a maxsize of at least 7")
561
+
562
+ if maxdepth is None:
563
+ maxdepth = maxsize
564
+
565
+ if isinstance(binary_operators, str):
566
+ binary_operators = [binary_operators]
567
+ if isinstance(unary_operators, str):
568
+ unary_operators = [unary_operators]
569
+
570
+ self.params = {
571
+ **dict(
572
+ weights=weights,
573
+ binary_operators=binary_operators,
574
+ unary_operators=unary_operators,
575
+ procs=procs,
576
+ loss=loss,
577
+ populations=populations,
578
+ niterations=niterations,
579
+ ncyclesperiteration=ncyclesperiteration,
580
+ alpha=alpha,
581
+ annealing=annealing,
582
+ fractionReplaced=fractionReplaced,
583
+ fractionReplacedHof=fractionReplacedHof,
584
+ npop=npop,
585
+ parsimony=float(parsimony),
586
+ migration=migration,
587
+ hofMigration=hofMigration,
588
+ shouldOptimizeConstants=shouldOptimizeConstants,
589
+ topn=topn,
590
+ weightAddNode=weightAddNode,
591
+ weightInsertNode=weightInsertNode,
592
+ weightDeleteNode=weightDeleteNode,
593
+ weightDoNothing=weightDoNothing,
594
+ weightMutateConstant=weightMutateConstant,
595
+ weightMutateOperator=weightMutateOperator,
596
+ weightRandomize=weightRandomize,
597
+ weightSimplify=weightSimplify,
598
+ perturbationFactor=perturbationFactor,
599
+ verbosity=verbosity,
600
+ progress=progress,
601
+ maxsize=maxsize,
602
+ fast_cycle=fast_cycle,
603
+ maxdepth=maxdepth,
604
+ batching=batching,
605
+ batchSize=batchSize,
606
+ select_k_features=select_k_features,
607
+ warmupMaxsizeBy=warmupMaxsizeBy,
608
+ constraints=constraints,
609
+ useFrequency=useFrequency,
610
+ tempdir=tempdir,
611
+ delete_tempfiles=delete_tempfiles,
612
+ update=update,
613
+ temp_equation_file=temp_equation_file,
614
+ optimizer_algorithm=optimizer_algorithm,
615
+ optimizer_nrestarts=optimizer_nrestarts,
616
+ optimize_probability=optimize_probability,
617
+ optimizer_iterations=optimizer_iterations,
618
+ tournament_selection_n=tournament_selection_n,
619
+ tournament_selection_p=tournament_selection_p,
620
+ denoise=denoise,
621
+ Xresampled=Xresampled,
622
+ precision=precision,
623
+ multithreading=multithreading,
624
+ ),
625
+ **kwargs,
626
+ }
627
+
628
+ # Stored equations:
629
+ self.equations = None
630
+
631
+ self.multioutput = None
632
+ self.raw_julia_output = None
633
+ self.equation_file = equation_file
634
+ self.n_features = None
635
+ self.extra_sympy_mappings = extra_sympy_mappings
636
+ self.extra_torch_mappings = extra_torch_mappings
637
+ self.extra_jax_mappings = extra_jax_mappings
638
+ self.output_jax_format = output_jax_format
639
+ self.output_torch_format = output_torch_format
640
+ self.nout = 1
641
+ self.selection = None
642
+ self.variable_names = variable_names
643
+ self.julia_project = julia_project
644
+
645
+ self.surface_parameters = [
646
+ "model_selection",
647
+ "multioutput",
648
+ "raw_julia_output",
649
+ "equation_file",
650
+ "n_features",
651
+ "extra_sympy_mappings",
652
+ "extra_torch_mappings",
653
+ "extra_jax_mappings",
654
+ "output_jax_format",
655
+ "output_torch_format",
656
+ "nout",
657
+ "selection",
658
+ "variable_names",
659
+ "julia_project"
660
+ ]
661
+
662
+ def __repr__(self):
663
+ if self.equations is None:
664
+ return "PySRRegressor.equations = None"
665
+
666
+ equations = self.equations
667
+ selected = ["" for _ in range(len(equations))]
668
+ if self.model_selection == "accuracy":
669
+ chosen_row = -1
670
+ elif self.model_selection == "best":
671
+ chosen_row = equations["score"].idxmax()
672
+ else:
673
+ raise NotImplementedError
674
+ selected[chosen_row] = ">>>>"
675
+ output = "PySRRegressor.equations = [\n"
676
+ repr_equations = pd.DataFrame(
677
+ dict(
678
+ pick=selected,
679
+ score=equations["score"],
680
+ Equation=equations["Equation"],
681
+ MSE=equations["MSE"],
682
+ Complexity=equations["Complexity"],
683
+ )
684
+ )
685
+ output += repr_equations.__repr__()
686
+ output += "\n]"
687
+ return output
688
+
689
+ def set_params(self, **params):
690
+ """Set parameters for pysr.pysr call or model_selection strategy."""
691
+ for key, value in params.items():
692
+ if key in self.surface_parameters:
693
+ self.__setattr__(key, value)
694
+ else:
695
+ self.params[key] = value
696
+
697
+ self.refresh()
698
+ return self
699
+
700
+ def get_params(self, deep=True):
701
+ del deep
702
+ return {
703
+ **self.params,
704
+ **{p: self.__getattribute__(key) for key in self.surface_parameters},
705
+ }
706
+
707
+ def get_best(self):
708
+ if self.equations is None:
709
+ return 0.0
710
+ if self.model_selection == "accuracy":
711
+ return self.equations.iloc[-1]
712
+ elif self.model_selection == "best":
713
+ return best_row(self.equations)
714
+ else:
715
+ raise NotImplementedError
716
+
717
+ def fit(self, X, y, weights=None, variable_names=None):
718
+ """Search for equations to fit the dataset.
719
+
720
+ :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
721
+ :type X: np.ndarray/pandas.DataFrame
722
+ :param y: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs). Putting in a 2D array will trigger a search for equations for each feature of y.
723
+ :type y: np.ndarray
724
+ :param weights: Optional. Same shape as y. Each element is how to weight the mean-square-error loss for that particular element of y.
725
+ :type weights: np.ndarray
726
+ :param variable_names: a list of names for the variables, other than "x0", "x1", etc.
727
+ :type variable_names: list
728
+ """
729
+ if variable_names is None:
730
+ variable_names = self.variable_names
731
+
732
+ self._run(
733
+ X=X,
734
+ y=y,
735
+ weights=weights,
736
+ variable_names=variable_names,
737
+ )
738
+
739
+ return self
740
+
741
+ def refresh(self):
742
+ # Updates self.equations with any new options passed,
743
+ # such as extra_sympy_mappings.
744
+ self.equations = self.get_hof()
745
+
746
+ def predict(self, X):
747
+ self.refresh()
748
+ np_format = self.get_best()["lambda_format"]
749
+ return np_format(X)
750
+
751
+ def sympy(self):
752
+ self.refresh()
753
+ return self.get_best()["sympy_format"]
754
+
755
+ def latex(self):
756
+ self.refresh()
757
+ return self.sympy().simplify()
758
+
759
+ def jax(self):
760
+ self.set_params(output_jax_format=True)
761
+ self.refresh()
762
+ return self.get_best()["jax_format"]
763
+
764
+ def pytorch(self):
765
+ self.set_params(output_torch_format=True)
766
+ self.refresh()
767
+ return self.get_best()["torch_format"]
768
+
769
+ def _run(self, X, y, weights, variable_names):
770
+ global already_ran
771
+ global Main
772
+
773
+ for key in self.surface_parameters:
774
+ if key in self.params:
775
+ raise ValueError(
776
+ f"{key} is a surface parameter, and cannot be in self.params"
777
+ )
778
+
779
+ multithreading = self.params["multithreading"]
780
+ procs = self.params["procs"]
781
+ binary_operators = self.params["binary_operators"]
782
+ unary_operators = self.params["unary_operators"]
783
+ batching = self.params["batching"]
784
+ maxsize = self.params["maxsize"]
785
+ select_k_features = self.params["select_k_features"]
786
+ Xresampled = self.params["Xresampled"]
787
+ denoise = self.params["denoise"]
788
+ constraints = self.params["constraints"]
789
+ update = self.params["update"]
790
+ loss = self.params["loss"]
791
+ weightMutateConstant = self.params["weightMutateConstant"]
792
+ weightMutateOperator = self.params["weightMutateOperator"]
793
+ weightAddNode = self.params["weightAddNode"]
794
+ weightInsertNode = self.params["weightInsertNode"]
795
+ weightDeleteNode = self.params["weightDeleteNode"]
796
+ weightSimplify = self.params["weightSimplify"]
797
+ weightRandomize = self.params["weightRandomize"]
798
+ weightDoNothing = self.params["weightDoNothing"]
799
+
800
+ if Main is None:
801
+ if multithreading:
802
+ os.environ["JULIA_NUM_THREADS"] = str(procs)
803
+
804
+ Main = init_julia()
805
+
806
+ if isinstance(X, pd.DataFrame):
807
+ if variable_names is not None:
808
+ warnings.warn("Resetting variable_names from X.columns")
809
+
810
+ variable_names = list(X.columns)
811
+ X = np.array(X)
812
+
813
+ if len(X.shape) == 1:
814
+ X = X[:, None]
815
+
816
+ assert not isinstance(y, pd.DataFrame)
817
+
818
+ if len(variable_names) == 0:
819
+ variable_names = [f"x{i}" for i in range(X.shape[1])]
820
+
821
+ use_custom_variable_names = len(variable_names) != 0
822
+ # TODO: this is always true.
823
+
824
+ _check_assertions(
825
+ X,
826
+ binary_operators,
827
+ unary_operators,
828
+ use_custom_variable_names,
829
+ variable_names,
830
+ weights,
831
+ y,
832
+ )
833
+
834
+ self.n_features = X.shape[1]
835
+
836
+ if len(X) > 10000 and not batching:
837
+ warnings.warn(
838
+ "Note: you are running with more than 10,000 datapoints. You should consider turning on batching (https://pysr.readthedocs.io/en/latest/docs/options/#batching). You should also reconsider if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form with symbolic regression. More datapoints will lower the search speed."
839
+ )
840
+
841
+ X, selection = _handle_feature_selection(
842
+ X, select_k_features, y, variable_names
843
+ )
844
+
845
+ if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
846
+ self.multioutput = False
847
+ nout = 1
848
+ y = y.reshape(-1)
849
+ elif len(y.shape) == 2:
850
+ self.multioutput = True
851
+ nout = y.shape[1]
852
+ else:
853
+ raise NotImplementedError("y shape not supported!")
854
+
855
+ if denoise:
856
+ if weights is not None:
857
+ raise NotImplementedError(
858
+ "No weights for denoising - the weights are learned."
859
+ )
860
+ if Xresampled is not None:
861
+ # Select among only the selected features:
862
+ if isinstance(Xresampled, pd.DataFrame):
863
+ # Handle Xresampled is pandas dataframe
864
+ if selection is not None:
865
+ Xresampled = Xresampled[[variable_names[i] for i in selection]]
866
+ else:
867
+ Xresampled = Xresampled[variable_names]
868
+ Xresampled = np.array(Xresampled)
869
+ else:
870
+ if selection is not None:
871
+ Xresampled = Xresampled[:, selection]
872
+ if self.multioutput:
873
+ y = np.stack(
874
+ [
875
+ _denoise(X, y[:, i], Xresampled=Xresampled)[1]
876
+ for i in range(nout)
877
+ ],
878
+ axis=1,
879
+ )
880
+ if Xresampled is not None:
881
+ X = Xresampled
882
+ else:
883
+ X, y = _denoise(X, y, Xresampled=Xresampled)
884
+
885
+ self.julia_project = _get_julia_project(self.julia_project)
886
+
887
+ tmpdir = Path(tempfile.mkdtemp(dir=self.params["tempdir"]))
888
+
889
+ if self.params["temp_equation_file"]:
890
+ self.equation_file = tmpdir / "hall_of_fame.csv"
891
+ elif self.equation_file is None:
892
+ date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
893
+ self.equation_file = "hall_of_fame_" + date_time + ".csv"
894
+
895
+ _create_inline_operators(
896
+ binary_operators=binary_operators, unary_operators=unary_operators
897
+ )
898
+ _handle_constraints(
899
+ binary_operators=binary_operators,
900
+ unary_operators=unary_operators,
901
+ constraints=constraints,
902
+ )
903
+
904
+ una_constraints = [constraints[op] for op in unary_operators]
905
+ bin_constraints = [constraints[op] for op in binary_operators]
906
+
907
+ try:
908
+ # TODO: is this needed since Julia now prints directly to stdout?
909
+ term_width = shutil.get_terminal_size().columns
910
+ except:
911
+ _, term_width = subprocess.check_output(["stty", "size"]).split()
912
+
913
+ if not already_ran:
914
+ from julia import Pkg
915
+
916
+ Pkg.activate(f"{_escape_filename(self.julia_project)}")
917
+ try:
918
+ if update:
919
+ Pkg.resolve()
920
+ Pkg.instantiate()
921
+ else:
922
+ Pkg.instantiate()
923
+ except RuntimeError as e:
924
+ raise ImportError(
925
+ f"""
926
+ Required dependencies are not installed or built. Run the following code in the Python REPL:
927
+
928
+ >>> import pysr
929
+ >>> pysr.install()
930
+
931
+ Tried to activate project {self.julia_project} but failed."""
932
+ ) from e
933
+ Main.eval("using SymbolicRegression")
934
+
935
+ Main.plus = Main.eval("(+)")
936
+ Main.sub = Main.eval("(-)")
937
+ Main.mult = Main.eval("(*)")
938
+ Main.pow = Main.eval("(^)")
939
+ Main.div = Main.eval("(/)")
940
+
941
+ Main.custom_loss = Main.eval(loss)
942
+
943
+ mutationWeights = [
944
+ float(weightMutateConstant),
945
+ float(weightMutateOperator),
946
+ float(weightAddNode),
947
+ float(weightInsertNode),
948
+ float(weightDeleteNode),
949
+ float(weightSimplify),
950
+ float(weightRandomize),
951
+ float(weightDoNothing),
952
+ ]
953
+
954
+ options = Main.Options(
955
+ binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
956
+ unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
957
+ bin_constraints=bin_constraints,
958
+ una_constraints=una_constraints,
959
+ loss=Main.custom_loss,
960
+ maxsize=int(maxsize),
961
+ hofFile=_escape_filename(self.equation_file),
962
+ npopulations=int(self.params["populations"]),
963
+ batching=batching,
964
+ batchSize=int(
965
+ min([self.params["batchSize"], len(X)]) if batching else len(X)
966
+ ),
967
+ mutationWeights=mutationWeights,
968
+ terminal_width=int(term_width),
969
+ probPickFirst=self.params["tournament_selection_p"],
970
+ ns=self.params["tournament_selection_n"],
971
+ # These have the same name:
972
+ parsimony=self.params["parsimony"],
973
+ alpha=self.params["alpha"],
974
+ maxdepth=self.params["maxdepth"],
975
+ fast_cycle=self.params["fast_cycle"],
976
+ migration=self.params["migration"],
977
+ hofMigration=self.params["hofMigration"],
978
+ fractionReplacedHof=self.params["fractionReplacedHof"],
979
+ shouldOptimizeConstants=self.params["shouldOptimizeConstants"],
980
+ warmupMaxsizeBy=self.params["warmupMaxsizeBy"],
981
+ useFrequency=self.params["useFrequency"],
982
+ npop=self.params["npop"],
983
+ ncyclesperiteration=self.params["ncyclesperiteration"],
984
+ fractionReplaced=self.params["fractionReplaced"],
985
+ topn=self.params["topn"],
986
+ verbosity=self.params["verbosity"],
987
+ optimizer_algorithm=self.params["optimizer_algorithm"],
988
+ optimizer_nrestarts=self.params["optimizer_nrestarts"],
989
+ optimize_probability=self.params["optimize_probability"],
990
+ optimizer_iterations=self.params["optimizer_iterations"],
991
+ perturbationFactor=self.params["perturbationFactor"],
992
+ annealing=self.params["annealing"],
993
+ )
994
+
995
+ np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[
996
+ self.params["precision"]
997
+ ]
998
+
999
+ Main.X = np.array(X, dtype=np_dtype).T
1000
+ if len(y.shape) == 1:
1001
+ Main.y = np.array(y, dtype=np_dtype)
1002
+ else:
1003
+ Main.y = np.array(y, dtype=np_dtype).T
1004
+ if weights is not None:
1005
+ if len(weights.shape) == 1:
1006
+ Main.weights = np.array(weights, dtype=np_dtype)
1007
+ else:
1008
+ Main.weights = np.array(weights, dtype=np_dtype).T
1009
+ else:
1010
+ Main.weights = None
1011
+
1012
+ cprocs = 0 if multithreading else procs
1013
+
1014
+ self.raw_julia_output = Main.EquationSearch(
1015
+ Main.X,
1016
+ Main.y,
1017
+ weights=Main.weights,
1018
+ niterations=int(self.params["niterations"]),
1019
+ varMap=(
1020
+ variable_names
1021
+ if selection is None
1022
+ else [variable_names[i] for i in selection]
1023
+ ),
1024
+ options=options,
1025
+ numprocs=int(cprocs),
1026
+ multithreading=bool(multithreading),
1027
+ )
1028
+
1029
+ self.variable_names = variable_names
1030
+ self.selection = selection
1031
+
1032
+ # Not in params:
1033
+ # selection, variable_names, multioutput
1034
+
1035
+ self.equations = self.get_hof()
1036
+
1037
+ if self.params["delete_tempfiles"]:
1038
+ shutil.rmtree(tmpdir)
1039
+
1040
+ already_ran = True
1041
+
1042
+ def get_hof(self):
1043
+ """Get the equations from a hall of fame file. If no arguments
1044
+ entered, the ones used previously from a call to PySR will be used."""
1045
+
1046
+ try:
1047
+ if self.multioutput:
1048
+ all_outputs = [
1049
+ pd.read_csv(
1050
+ str(self.equation_file) + f".out{i}" + ".bkup",
1051
+ sep="|",
1052
+ )
1053
+ for i in range(1, self.nout + 1)
1054
+ ]
1055
+ else:
1056
+ all_outputs = [pd.read_csv(str(self.equation_file) + ".bkup", sep="|")]
1057
+ except FileNotFoundError:
1058
+ raise RuntimeError(
1059
+ "Couldn't find equation file! The equation search likely exited before a single iteration completed."
1060
+ )
1061
+
1062
+ ret_outputs = []
1063
+
1064
+ for output in all_outputs:
1065
+
1066
+ scores = []
1067
+ lastMSE = None
1068
+ lastComplexity = 0
1069
+ sympy_format = []
1070
+ lambda_format = []
1071
+ if self.output_jax_format:
1072
+ jax_format = []
1073
+ if self.output_torch_format:
1074
+ torch_format = []
1075
+ use_custom_variable_names = len(self.variable_names) != 0
1076
+ local_sympy_mappings = {
1077
+ **self.extra_sympy_mappings,
1078
+ **sympy_mappings,
1079
+ }
1080
+
1081
+ if use_custom_variable_names:
1082
+ sympy_symbols = [
1083
+ sympy.Symbol(self.variable_names[i]) for i in range(self.n_features)
1084
+ ]
1085
+ else:
1086
+ sympy_symbols = [
1087
+ sympy.Symbol("x%d" % i) for i in range(self.n_features)
1088
+ ]
1089
+
1090
+ for _, eqn_row in output.iterrows():
1091
+ eqn = sympify(eqn_row["Equation"], locals=local_sympy_mappings)
1092
+ sympy_format.append(eqn)
1093
+
1094
+ # Numpy:
1095
+ lambda_format.append(
1096
+ CallableEquation(
1097
+ sympy_symbols, eqn, self.selection, self.variable_names
1098
+ )
1099
+ )
1100
+
1101
+ # JAX:
1102
+ if self.output_jax_format:
1103
+ from .export_jax import sympy2jax
1104
+
1105
+ func, params = sympy2jax(
1106
+ eqn,
1107
+ sympy_symbols,
1108
+ selection=self.selection,
1109
+ extra_jax_mappings=self.extra_jax_mappings,
1110
+ )
1111
+ jax_format.append({"callable": func, "parameters": params})
1112
+
1113
+ # Torch:
1114
+ if self.output_torch_format:
1115
+ from .export_torch import sympy2torch
1116
+
1117
+ module = sympy2torch(
1118
+ eqn,
1119
+ sympy_symbols,
1120
+ selection=self.selection,
1121
+ extra_torch_mappings=self.extra_torch_mappings,
1122
+ )
1123
+ torch_format.append(module)
1124
+
1125
+ curMSE = eqn_row["MSE"]
1126
+ curComplexity = eqn_row["Complexity"]
1127
+
1128
+ if lastMSE is None:
1129
+ cur_score = 0.0
1130
+ else:
1131
+ if curMSE > 0.0:
1132
+ cur_score = -np.log(curMSE / lastMSE) / (
1133
+ curComplexity - lastComplexity
1134
+ )
1135
+ else:
1136
+ cur_score = np.inf
1137
+
1138
+ scores.append(cur_score)
1139
+ lastMSE = curMSE
1140
+ lastComplexity = curComplexity
1141
+
1142
+ output["score"] = np.array(scores)
1143
+ output["sympy_format"] = sympy_format
1144
+ output["lambda_format"] = lambda_format
1145
+ output_cols = [
1146
+ "Complexity",
1147
+ "MSE",
1148
+ "score",
1149
+ "Equation",
1150
+ "sympy_format",
1151
+ "lambda_format",
1152
+ ]
1153
+ if self.output_jax_format:
1154
+ output_cols += ["jax_format"]
1155
+ output["jax_format"] = jax_format
1156
+ if self.output_torch_format:
1157
+ output_cols += ["torch_format"]
1158
+ output["torch_format"] = torch_format
1159
+
1160
+ ret_outputs.append(output[output_cols])
1161
+
1162
+ if self.multioutput:
1163
+ return ret_outputs
1164
+ return ret_outputs[0]