MilesCranmer commited on
Commit
b5d0afb
2 Parent(s): 2b07f83 7d4300a

Merge pull request #49 from MilesCranmer/deepsource-transform-9e0a82d1

Browse files
benchmarks/hyperparamopt.py CHANGED
@@ -10,6 +10,7 @@ import time
10
  import contextlib
11
  import numpy as np
12
 
 
13
  @contextlib.contextmanager
14
  def temp_seed(seed):
15
  state = np.random.get_state()
@@ -20,11 +21,12 @@ def temp_seed(seed):
20
  np.random.set_state(state)
21
 
22
 
23
- #Change the following code to your file
24
  ################################################################################
25
- TRIALS_FOLDER = 'trials'
26
  NUMBER_TRIALS_PER_RUN = 1
27
 
 
28
  def run_trial(args):
29
  """Evaluate the model loss using the hyperparams in args
30
 
@@ -34,29 +36,29 @@ def run_trial(args):
34
  """
35
 
36
  print("Running on", args)
37
- args['niterations'] = 100
38
- args['npop'] = 100
39
- args['ncyclesperiteration'] = 1000
40
- args['topn'] = 10
41
- args['parsimony'] = 0.0
42
- args['useFrequency'] = True
43
- args['annealing'] = True
44
-
45
- if args['npop'] < 20 or args['ncyclesperiteration'] < 3:
46
  print("Bad parameters")
47
- return {'status': 'ok', 'loss': np.inf}
48
 
49
- args['weightDoNothing'] = 1.0
50
  ntrials = 3
51
 
52
  with temp_seed(0):
53
- X = np.random.randn(100, 10)*3
54
 
55
  eval_str = [
56
- "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5",
57
- "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)",
58
- "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)",
59
- "X[:, 0] * np.sin(2*np.pi * (X[:, 1] * X[:, 2] - X[:, 3] / X[:, 4])) + 3.0"
60
  ]
61
 
62
  print(f"Starting", str(args))
@@ -67,51 +69,50 @@ def run_trial(args):
67
  for j in range(ntrials):
68
  print(f"Starting trial {j}")
69
  y = eval(eval_str[i])
70
- trial = pysr.pysr(X, y,
 
 
71
  procs=4,
72
  populations=20,
73
  binary_operators=["plus", "mult", "pow", "div"],
74
  unary_operators=["cos", "exp", "sin", "logm", "abs"],
75
  maxsize=25,
76
- constraints={'pow': (-1, 1)},
77
- **args)
78
- if len(trial) == 0: raise ValueError
 
 
79
  trials.append(
80
- np.min(trial['MSE'])**0.5 / np.std(eval(eval_str[i-1]))
81
  )
82
  print(f"Test {i} trial {j} with", str(args), f"got {trials[-1]}")
83
 
84
  except ValueError:
85
  print(f"Broken", str(args))
86
- return {
87
- 'status': 'ok', # or 'fail' if nan loss
88
- 'loss': np.inf
89
- }
90
  loss = np.average(trials)
91
  print(f"Finished with {loss}", str(args))
92
 
93
- return {
94
- 'status': 'ok', # or 'fail' if nan loss
95
- 'loss': loss
96
- }
97
 
98
 
99
  space = {
100
- 'alpha': hp.lognormal('alpha', np.log(10.0), 1.0),
101
- 'fractionReplacedHof': hp.lognormal('fractionReplacedHof', np.log(0.1), 1.0),
102
- 'fractionReplaced': hp.lognormal('fractionReplaced', np.log(0.1), 1.0),
103
- 'perturbationFactor': hp.lognormal('perturbationFactor', np.log(1.0), 1.0),
104
- 'weightMutateConstant': hp.lognormal('weightMutateConstant', np.log(4.0), 1.0),
105
- 'weightMutateOperator': hp.lognormal('weightMutateOperator', np.log(0.5), 1.0),
106
- 'weightAddNode': hp.lognormal('weightAddNode', np.log(0.5), 1.0),
107
- 'weightInsertNode': hp.lognormal('weightInsertNode', np.log(0.5), 1.0),
108
- 'weightDeleteNode': hp.lognormal('weightDeleteNode', np.log(0.5), 1.0),
109
- 'weightSimplify': hp.lognormal('weightSimplify', np.log(0.05), 1.0),
110
- 'weightRandomize': hp.lognormal('weightRandomize', np.log(0.25), 1.0),
111
  }
112
 
113
  ################################################################################
114
 
 
115
  def merge_trials(trials1, trials2_slice):
116
  """Merge two hyperopt trials objects
117
 
@@ -123,24 +124,23 @@ def merge_trials(trials1, trials2_slice):
123
  """
124
  max_tid = 0
125
  if len(trials1.trials) > 0:
126
- max_tid = max([trial['tid'] for trial in trials1.trials])
127
 
128
  for trial in trials2_slice:
129
- tid = trial['tid'] + max_tid + 1
130
  hyperopt_trial = Trials().new_trial_docs(
131
- tids=[None],
132
- specs=[None],
133
- results=[None],
134
- miscs=[None])
135
  hyperopt_trial[0] = trial
136
- hyperopt_trial[0]['tid'] = tid
137
- hyperopt_trial[0]['misc']['tid'] = tid
138
- for key in hyperopt_trial[0]['misc']['idxs'].keys():
139
- hyperopt_trial[0]['misc']['idxs'][key] = [tid]
140
- trials1.insert_trial_docs(hyperopt_trial)
141
  trials1.refresh()
142
  return trials1
143
 
 
144
  loaded_fnames = []
145
  trials = None
146
  # Run new hyperparameter trials until killed
@@ -149,15 +149,16 @@ while True:
149
 
150
  # Load up all runs:
151
  import glob
152
- path = TRIALS_FOLDER + '/*.pkl'
 
153
  for fname in glob.glob(path):
154
  if fname in loaded_fnames:
155
  continue
156
 
157
- trials_obj = pkl.load(open(fname, 'rb'))
158
- n_trials = trials_obj['n']
159
- trials_obj = trials_obj['trials']
160
- if len(loaded_fnames) == 0:
161
  trials = trials_obj
162
  else:
163
  print("Merging trials")
@@ -171,23 +172,29 @@ while True:
171
 
172
  n = NUMBER_TRIALS_PER_RUN
173
  try:
174
- best = fmin(run_trial,
 
175
  space=space,
176
  algo=tpe.suggest,
177
  max_evals=n + len(trials.trials),
178
  trials=trials,
179
  verbose=1,
180
- rstate=np.random.RandomState(np.random.randint(1,10**6))
181
- )
182
  except hyperopt.exceptions.AllTrialsFailed:
183
  continue
184
 
185
- print('current best', best)
186
  hyperopt_trial = Trials()
187
 
188
  # Merge with empty trials dataset:
189
  save_trials = merge_trials(hyperopt_trial, trials.trials[-n:])
190
- new_fname = TRIALS_FOLDER + '/' + str(np.random.randint(0, sys.maxsize)) + str(time.time()) + '.pkl'
191
- pkl.dump({'trials': save_trials, 'n': n}, open(new_fname, 'wb'))
 
 
 
 
 
 
192
  loaded_fnames.append(new_fname)
193
-
 
10
  import contextlib
11
  import numpy as np
12
 
13
+
14
  @contextlib.contextmanager
15
  def temp_seed(seed):
16
  state = np.random.get_state()
 
21
  np.random.set_state(state)
22
 
23
 
24
+ # Change the following code to your file
25
  ################################################################################
26
+ TRIALS_FOLDER = "trials"
27
  NUMBER_TRIALS_PER_RUN = 1
28
 
29
+
30
  def run_trial(args):
31
  """Evaluate the model loss using the hyperparams in args
32
 
 
36
  """
37
 
38
  print("Running on", args)
39
+ args["niterations"] = 100
40
+ args["npop"] = 100
41
+ args["ncyclesperiteration"] = 1000
42
+ args["topn"] = 10
43
+ args["parsimony"] = 0.0
44
+ args["useFrequency"] = True
45
+ args["annealing"] = True
46
+
47
+ if args["npop"] < 20 or args["ncyclesperiteration"] < 3:
48
  print("Bad parameters")
49
+ return {"status": "ok", "loss": np.inf}
50
 
51
+ args["weightDoNothing"] = 1.0
52
  ntrials = 3
53
 
54
  with temp_seed(0):
55
+ X = np.random.randn(100, 10) * 3
56
 
57
  eval_str = [
58
+ "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5",
59
+ "np.exp(X[:, 0]/2) + 12.0 + np.log(np.abs(X[:, 0])*10 + 1)",
60
+ "(np.exp(X[:, 3]) + 3)/(np.abs(X[:, 1]) + np.cos(X[:, 0]) + 1.1)",
61
+ "X[:, 0] * np.sin(2*np.pi * (X[:, 1] * X[:, 2] - X[:, 3] / X[:, 4])) + 3.0",
62
  ]
63
 
64
  print(f"Starting", str(args))
 
69
  for j in range(ntrials):
70
  print(f"Starting trial {j}")
71
  y = eval(eval_str[i])
72
+ trial = pysr.pysr(
73
+ X,
74
+ y,
75
  procs=4,
76
  populations=20,
77
  binary_operators=["plus", "mult", "pow", "div"],
78
  unary_operators=["cos", "exp", "sin", "logm", "abs"],
79
  maxsize=25,
80
+ constraints={"pow": (-1, 1)},
81
+ **args,
82
+ )
83
+ if len(trial) == 0:
84
+ raise ValueError
85
  trials.append(
86
+ np.min(trial["MSE"]) ** 0.5 / np.std(eval(eval_str[i - 1]))
87
  )
88
  print(f"Test {i} trial {j} with", str(args), f"got {trials[-1]}")
89
 
90
  except ValueError:
91
  print(f"Broken", str(args))
92
+ return {"status": "ok", "loss": np.inf} # or 'fail' if nan loss
 
 
 
93
  loss = np.average(trials)
94
  print(f"Finished with {loss}", str(args))
95
 
96
+ return {"status": "ok", "loss": loss} # or 'fail' if nan loss
 
 
 
97
 
98
 
99
  space = {
100
+ "alpha": hp.lognormal("alpha", np.log(10.0), 1.0),
101
+ "fractionReplacedHof": hp.lognormal("fractionReplacedHof", np.log(0.1), 1.0),
102
+ "fractionReplaced": hp.lognormal("fractionReplaced", np.log(0.1), 1.0),
103
+ "perturbationFactor": hp.lognormal("perturbationFactor", np.log(1.0), 1.0),
104
+ "weightMutateConstant": hp.lognormal("weightMutateConstant", np.log(4.0), 1.0),
105
+ "weightMutateOperator": hp.lognormal("weightMutateOperator", np.log(0.5), 1.0),
106
+ "weightAddNode": hp.lognormal("weightAddNode", np.log(0.5), 1.0),
107
+ "weightInsertNode": hp.lognormal("weightInsertNode", np.log(0.5), 1.0),
108
+ "weightDeleteNode": hp.lognormal("weightDeleteNode", np.log(0.5), 1.0),
109
+ "weightSimplify": hp.lognormal("weightSimplify", np.log(0.05), 1.0),
110
+ "weightRandomize": hp.lognormal("weightRandomize", np.log(0.25), 1.0),
111
  }
112
 
113
  ################################################################################
114
 
115
+
116
  def merge_trials(trials1, trials2_slice):
117
  """Merge two hyperopt trials objects
118
 
 
124
  """
125
  max_tid = 0
126
  if len(trials1.trials) > 0:
127
+ max_tid = max([trial["tid"] for trial in trials1.trials])
128
 
129
  for trial in trials2_slice:
130
+ tid = trial["tid"] + max_tid + 1
131
  hyperopt_trial = Trials().new_trial_docs(
132
+ tids=[None], specs=[None], results=[None], miscs=[None]
133
+ )
 
 
134
  hyperopt_trial[0] = trial
135
+ hyperopt_trial[0]["tid"] = tid
136
+ hyperopt_trial[0]["misc"]["tid"] = tid
137
+ for key in hyperopt_trial[0]["misc"]["idxs"].keys():
138
+ hyperopt_trial[0]["misc"]["idxs"][key] = [tid]
139
+ trials1.insert_trial_docs(hyperopt_trial)
140
  trials1.refresh()
141
  return trials1
142
 
143
+
144
  loaded_fnames = []
145
  trials = None
146
  # Run new hyperparameter trials until killed
 
149
 
150
  # Load up all runs:
151
  import glob
152
+
153
+ path = TRIALS_FOLDER + "/*.pkl"
154
  for fname in glob.glob(path):
155
  if fname in loaded_fnames:
156
  continue
157
 
158
+ trials_obj = pkl.load(open(fname, "rb"))
159
+ n_trials = trials_obj["n"]
160
+ trials_obj = trials_obj["trials"]
161
+ if len(loaded_fnames) == 0:
162
  trials = trials_obj
163
  else:
164
  print("Merging trials")
 
172
 
173
  n = NUMBER_TRIALS_PER_RUN
174
  try:
175
+ best = fmin(
176
+ run_trial,
177
  space=space,
178
  algo=tpe.suggest,
179
  max_evals=n + len(trials.trials),
180
  trials=trials,
181
  verbose=1,
182
+ rstate=np.random.RandomState(np.random.randint(1, 10 ** 6)),
183
+ )
184
  except hyperopt.exceptions.AllTrialsFailed:
185
  continue
186
 
187
+ print("current best", best)
188
  hyperopt_trial = Trials()
189
 
190
  # Merge with empty trials dataset:
191
  save_trials = merge_trials(hyperopt_trial, trials.trials[-n:])
192
+ new_fname = (
193
+ TRIALS_FOLDER
194
+ + "/"
195
+ + str(np.random.randint(0, sys.maxsize))
196
+ + str(time.time())
197
+ + ".pkl"
198
+ )
199
+ pkl.dump({"trials": save_trials, "n": n}, open(new_fname, "wb"))
200
  loaded_fnames.append(new_fname)
 
example.py CHANGED
@@ -2,18 +2,25 @@ import numpy as np
2
  from pysr import pysr, best
3
 
4
  # Dataset
5
- X = 2*np.random.randn(100, 5)
6
- y = 2*np.cos(X[:, 3]) + X[:, 0]**2 - 2
7
 
8
  # Learn equations
9
- equations = pysr(X, y, niterations=5,
 
 
 
10
  binary_operators=["plus", "mult"],
11
  unary_operators=[
12
- "cos", "exp", "sin", #Pre-defined library of operators (see https://pysr.readthedocs.io/en/latest/docs/operators/)
13
- "inv(x) = 1/x"],
14
- loss='loss(x, y) = abs(x - y)', # Custom loss function
15
- julia_project="../SymbolicRegression.jl") # Define your own operator! (Julia syntax)
 
 
 
 
16
 
17
- ...# (you can use ctl-c to exit early)
18
 
19
  print(best(equations))
 
2
  from pysr import pysr, best
3
 
4
  # Dataset
5
+ X = 2 * np.random.randn(100, 5)
6
+ y = 2 * np.cos(X[:, 3]) + X[:, 0] ** 2 - 2
7
 
8
  # Learn equations
9
+ equations = pysr(
10
+ X,
11
+ y,
12
+ niterations=5,
13
  binary_operators=["plus", "mult"],
14
  unary_operators=[
15
+ "cos",
16
+ "exp",
17
+ "sin", # Pre-defined library of operators (see https://pysr.readthedocs.io/en/latest/docs/operators/)
18
+ "inv(x) = 1/x",
19
+ ],
20
+ loss="loss(x, y) = abs(x - y)", # Custom loss function
21
+ julia_project="../SymbolicRegression.jl",
22
+ ) # Define your own operator! (Julia syntax)
23
 
24
+ ... # (you can use ctl-c to exit early)
25
 
26
  print(best(equations))
pysr/export_jax.py CHANGED
@@ -58,14 +58,16 @@ def sympy2jaxtext(expr, parameters, symbols_in):
58
  elif issubclass(expr.func, sympy.Integer):
59
  return f"{int(expr)}"
60
  elif issubclass(expr.func, sympy.Symbol):
61
- return f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
 
 
62
  else:
63
  _func = _jnp_func_lookup[expr.func]
64
  args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
65
  if _func == MUL:
66
- return ' * '.join(['(' + arg + ')' for arg in args])
67
  elif _func == ADD:
68
- return ' + '.join(['(' + arg + ')' for arg in args])
69
  else:
70
  return f'{_func}({", ".join(args)})'
71
 
@@ -75,6 +77,7 @@ jax = None
75
  jnp = None
76
  jsp = None
77
 
 
78
  def _initialize_jax():
79
  global jax_initialized
80
  global jax
@@ -85,6 +88,7 @@ def _initialize_jax():
85
  import jax as _jax
86
  from jax import numpy as _jnp
87
  from jax.scipy import special as _jsp
 
88
  jax = _jax
89
  jnp = _jnp
90
  jsp = _jsp
@@ -169,7 +173,7 @@ def sympy2jax(expression, symbols_in, selection=None):
169
 
170
  parameters = []
171
  functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
172
- hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
173
  text = f"def {hash_string}(X, parameters):\n"
174
  if selection is not None:
175
  # Impose the feature selection:
 
58
  elif issubclass(expr.func, sympy.Integer):
59
  return f"{int(expr)}"
60
  elif issubclass(expr.func, sympy.Symbol):
61
+ return (
62
+ f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
63
+ )
64
  else:
65
  _func = _jnp_func_lookup[expr.func]
66
  args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
67
  if _func == MUL:
68
+ return " * ".join(["(" + arg + ")" for arg in args])
69
  elif _func == ADD:
70
+ return " + ".join(["(" + arg + ")" for arg in args])
71
  else:
72
  return f'{_func}({", ".join(args)})'
73
 
 
77
  jnp = None
78
  jsp = None
79
 
80
+
81
  def _initialize_jax():
82
  global jax_initialized
83
  global jax
 
88
  import jax as _jax
89
  from jax import numpy as _jnp
90
  from jax.scipy import special as _jsp
91
+
92
  jax = _jax
93
  jnp = _jnp
94
  jsp = _jsp
 
173
 
174
  parameters = []
175
  functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
176
+ hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
177
  text = f"def {hash_string}(X, parameters):\n"
178
  if selection is not None:
179
  # Impose the feature selection:
pysr/export_torch.py CHANGED
@@ -7,17 +7,21 @@ import collections as co
7
  import functools as ft
8
  import sympy
9
 
 
10
  def _reduce(fn):
11
  def fn_(*args):
12
  return ft.reduce(fn, args)
 
13
  return fn_
14
 
 
15
  torch_initialized = False
16
  torch = None
17
  _global_func_lookup = None
18
  _Node = None
19
  SingleSymPyModule = None
20
 
 
21
  def _initialize_torch():
22
  global torch_initialized
23
  global torch
@@ -29,6 +33,7 @@ def _initialize_torch():
29
  # but still allow this module to be loaded in __init__
30
  if not torch_initialized:
31
  import torch as _torch
 
32
  torch = _torch
33
 
34
  _global_func_lookup = {
@@ -85,6 +90,7 @@ def _initialize_torch():
85
 
86
  class _Node(torch.nn.Module):
87
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
 
88
  def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
89
  super().__init__(**kwargs)
90
 
@@ -95,9 +101,13 @@ def _initialize_torch():
95
  self._torch_func = lambda: self._value
96
  self._args = ()
97
  elif issubclass(expr.func, sympy.UnevaluatedExpr):
98
- if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
99
- raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
100
- self.register_buffer('_value', torch.tensor(float(expr.args[0])))
 
 
 
 
101
  self._torch_func = lambda: self._value
102
  self._args = ()
103
  elif issubclass(expr.func, sympy.Integer):
@@ -117,7 +127,12 @@ def _initialize_torch():
117
  try:
118
  arg_ = _memodict[arg]
119
  except KeyError:
120
- arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
 
 
 
 
 
121
  _memodict[arg] = arg_
122
  args.append(arg_)
123
  self._args = torch.nn.ModuleList(args)
@@ -133,19 +148,22 @@ def _initialize_torch():
133
  args.append(arg_)
134
  return self._torch_func(*args)
135
 
136
-
137
  class SingleSymPyModule(torch.nn.Module):
138
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
139
- def __init__(self, expression, symbols_in,
140
- selection=None, extra_funcs=None, **kwargs):
 
 
141
  super().__init__(**kwargs)
142
-
143
  if extra_funcs is None:
144
  extra_funcs = {}
145
  _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
146
 
147
  _memodict = {}
148
- self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
 
 
149
  self._expression_string = str(expression)
150
  self._selection = selection
151
  self.symbols_in = [str(symbol) for symbol in symbols_in]
@@ -156,13 +174,11 @@ def _initialize_torch():
156
  def forward(self, X):
157
  if self._selection is not None:
158
  X = X[:, self._selection]
159
- symbols = {symbol: X[:, i]
160
- for i, symbol in enumerate(self.symbols_in)}
161
  return self._node(symbols)
162
 
163
 
164
- def sympy2torch(expression, symbols_in,
165
- selection=None, extra_torch_mappings=None):
166
  """Returns a module for a given sympy expression with trainable parameters;
167
 
168
  This function will assume the input to the module is a matrix X, where
@@ -172,6 +188,6 @@ def sympy2torch(expression, symbols_in,
172
 
173
  _initialize_torch()
174
 
175
- return SingleSymPyModule(expression, symbols_in,
176
- selection=selection,
177
- extra_funcs=extra_torch_mappings)
 
7
  import functools as ft
8
  import sympy
9
 
10
+
11
  def _reduce(fn):
12
  def fn_(*args):
13
  return ft.reduce(fn, args)
14
+
15
  return fn_
16
 
17
+
18
  torch_initialized = False
19
  torch = None
20
  _global_func_lookup = None
21
  _Node = None
22
  SingleSymPyModule = None
23
 
24
+
25
  def _initialize_torch():
26
  global torch_initialized
27
  global torch
 
33
  # but still allow this module to be loaded in __init__
34
  if not torch_initialized:
35
  import torch as _torch
36
+
37
  torch = _torch
38
 
39
  _global_func_lookup = {
 
90
 
91
  class _Node(torch.nn.Module):
92
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
93
+
94
  def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
95
  super().__init__(**kwargs)
96
 
 
101
  self._torch_func = lambda: self._value
102
  self._args = ()
103
  elif issubclass(expr.func, sympy.UnevaluatedExpr):
104
+ if len(expr.args) != 1 or not issubclass(
105
+ expr.args[0].func, sympy.Float
106
+ ):
107
+ raise ValueError(
108
+ "UnevaluatedExpr should only be used to wrap floats."
109
+ )
110
+ self.register_buffer("_value", torch.tensor(float(expr.args[0])))
111
  self._torch_func = lambda: self._value
112
  self._args = ()
113
  elif issubclass(expr.func, sympy.Integer):
 
127
  try:
128
  arg_ = _memodict[arg]
129
  except KeyError:
130
+ arg_ = type(self)(
131
+ expr=arg,
132
+ _memodict=_memodict,
133
+ _func_lookup=_func_lookup,
134
+ **kwargs,
135
+ )
136
  _memodict[arg] = arg_
137
  args.append(arg_)
138
  self._args = torch.nn.ModuleList(args)
 
148
  args.append(arg_)
149
  return self._torch_func(*args)
150
 
 
151
  class SingleSymPyModule(torch.nn.Module):
152
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
153
+
154
+ def __init__(
155
+ self, expression, symbols_in, selection=None, extra_funcs=None, **kwargs
156
+ ):
157
  super().__init__(**kwargs)
158
+
159
  if extra_funcs is None:
160
  extra_funcs = {}
161
  _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
162
 
163
  _memodict = {}
164
+ self._node = _Node(
165
+ expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
166
+ )
167
  self._expression_string = str(expression)
168
  self._selection = selection
169
  self.symbols_in = [str(symbol) for symbol in symbols_in]
 
174
  def forward(self, X):
175
  if self._selection is not None:
176
  X = X[:, self._selection]
177
+ symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
 
178
  return self._node(symbols)
179
 
180
 
181
+ def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
 
182
  """Returns a module for a given sympy expression with trainable parameters;
183
 
184
  This function will assume the input to the module is a matrix X, where
 
188
 
189
  _initialize_torch()
190
 
191
+ return SingleSymPyModule(
192
+ expression, symbols_in, selection=selection, extra_funcs=extra_torch_mappings
193
+ )
pysr/feynman_problems.py CHANGED
@@ -7,6 +7,7 @@ from pathlib import Path
7
  PKG_DIR = Path(__file__).parents[1]
8
  FEYNMAN_DATASET = PKG_DIR / "datasets" / "FeynmanEquations.csv"
9
 
 
10
  class Problem:
11
  """
12
  Problem API to work with PySR.
@@ -15,6 +16,7 @@ class Problem:
15
 
16
  Should be able to call pysr(problem.X, problem.y, var_names=problem.var_names) and have it work
17
  """
 
18
  def __init__(self, X, y, form=None, variable_names=None):
19
  self.X = X
20
  self.y = y
@@ -27,34 +29,39 @@ class FeynmanProblem(Problem):
27
  Stores the data for the problems from the 100 Feynman Equations on Physics.
28
  This is the benchmark used in the AI Feynman Paper
29
  """
 
30
  def __init__(self, row, gen=False, dp=500):
31
  """
32
  row: a row read as a dict from the FeynmanEquations dataset provided in the datasets folder of the repo
33
  gen: If true the problem will have dp X and y values randomly generated else they will be None
34
  """
35
- self.eq_id = row['Filename']
36
- self.n_vars = int(row['# variables'])
37
- super(FeynmanProblem, self).__init__(None, None, form=row['Formula'],
38
- variable_names=[row[f'v{i + 1}_name'] for i in range(self.n_vars)])
39
- self.low = [float(row[f'v{i+1}_low']) for i in range(self.n_vars)]
40
- self.high = [float(row[f'v{i+1}_high']) for i in range(self.n_vars)]
41
- self.dp = dp
 
 
 
 
42
  if gen:
43
  self.X = np.random.uniform(0.01, 25, size=(self.dp, self.n_vars))
44
  d = {}
45
  for var in range(len(self.variable_names)):
46
  d[self.variable_names[var]] = self.X[:, var]
47
- d['exp'] = np.exp
48
- d['sqrt'] = np.sqrt
49
- d['pi'] = np.pi
50
- d['cos'] = np.cos
51
- d['sin'] = np.sin
52
- d['tan'] = np.tan
53
- d['tanh'] = np.tanh
54
- d['ln'] = np.log
55
- d['log'] = np.log # Quite sure the Feynman dataset has no base 10 logs
56
- d['arcsin'] = np.arcsin
57
- self.y = eval(self.form,d)
58
  return
59
 
60
  def __str__(self):
@@ -77,7 +84,8 @@ class FeynmanProblem(Problem):
77
  for i, row in enumerate(reader):
78
  if ind > first:
79
  break
80
- if row['Filename'] == '': continue
 
81
  try:
82
  p = FeynmanProblem(row, gen=gen, dp=dp)
83
  ret.append(p)
@@ -93,18 +101,34 @@ def run_on_problem(problem, verbosity=0, multiprocessing=True):
93
  Takes in a problem and returns a tuple: (equations, best predicted equation, actual equation)
94
  """
95
  from time import time
 
96
  starting = time()
97
- equations = pysr(problem.X, problem.y, variable_names=problem.variable_names, verbosity=verbosity,)
98
- timing = time()-starting
 
 
 
 
 
99
  others = {"time": timing, "problem": problem}
100
  if not multiprocessing:
101
- others['equations'] = equations
102
  return str(best(equations)), problem.form, others
103
 
104
- def do_feynman_experiments_parallel(first=100, verbosity=0, dp=500, output_file_path="FeynmanExperiment.csv", data_dir=FEYNMAN_DATASET):
 
 
 
 
 
 
 
105
  import multiprocessing as mp
106
  from tqdm import tqdm
107
- problems = FeynmanProblem.mk_problems(first=first, gen=True, dp=dp, data_dir=data_dir)
 
 
 
108
  ids = []
109
  predictions = []
110
  true_equations = []
@@ -117,22 +141,31 @@ def do_feynman_experiments_parallel(first=100, verbosity=0, dp=500, output_file_
117
  pbar.update()
118
  for res in results:
119
  prediction, true_equation, others = res
120
- problem = others['problem']
121
  ids.append(problem.eq_id)
122
  predictions.append(prediction)
123
  true_equations.append(true_equation)
124
- time_takens.append(others['time'])
125
- with open(output_file_path, 'a') as f:
126
- writer = csv.writer(f, delimiter=',')
127
- writer.writerow(['ID', 'Predicted', 'True', 'Time'])
128
  for i in range(len(ids)):
129
  writer.writerow([ids[i], predictions[i], true_equations[i], time_takens[i]])
130
  return
131
 
132
- def do_feynman_experiments(first=100, verbosity=0, dp=500, output_file_path="FeynmanExperiment.csv", data_dir=FEYNMAN_DATASET):
 
 
 
 
 
 
 
133
  from tqdm import tqdm
134
 
135
- problems = FeynmanProblem.mk_problems(first=first, gen=True, dp=dp, data_dir=data_dir)
 
 
136
  indx = range(len(problems))
137
  ids = []
138
  predictions = []
@@ -143,10 +176,10 @@ def do_feynman_experiments(first=100, verbosity=0, dp=500, output_file_path="Fey
143
  ids.append(problem.eq_id)
144
  predictions.append(prediction)
145
  true_equations.append(true_equation)
146
- time_takens.append(others['time'])
147
- with open(output_file_path, 'a') as f:
148
- writer = csv.writer(f, delimiter=',')
149
- writer.writerow(['ID', 'Predicted', 'True', 'Time'])
150
  for i in range(len(ids)):
151
  writer.writerow([ids[i], predictions[i], true_equations[i], time_takens[i]])
152
  return
 
7
  PKG_DIR = Path(__file__).parents[1]
8
  FEYNMAN_DATASET = PKG_DIR / "datasets" / "FeynmanEquations.csv"
9
 
10
+
11
  class Problem:
12
  """
13
  Problem API to work with PySR.
 
16
 
17
  Should be able to call pysr(problem.X, problem.y, var_names=problem.var_names) and have it work
18
  """
19
+
20
  def __init__(self, X, y, form=None, variable_names=None):
21
  self.X = X
22
  self.y = y
 
29
  Stores the data for the problems from the 100 Feynman Equations on Physics.
30
  This is the benchmark used in the AI Feynman Paper
31
  """
32
+
33
  def __init__(self, row, gen=False, dp=500):
34
  """
35
  row: a row read as a dict from the FeynmanEquations dataset provided in the datasets folder of the repo
36
  gen: If true the problem will have dp X and y values randomly generated else they will be None
37
  """
38
+ self.eq_id = row["Filename"]
39
+ self.n_vars = int(row["# variables"])
40
+ super(FeynmanProblem, self).__init__(
41
+ None,
42
+ None,
43
+ form=row["Formula"],
44
+ variable_names=[row[f"v{i + 1}_name"] for i in range(self.n_vars)],
45
+ )
46
+ self.low = [float(row[f"v{i+1}_low"]) for i in range(self.n_vars)]
47
+ self.high = [float(row[f"v{i+1}_high"]) for i in range(self.n_vars)]
48
+ self.dp = dp
49
  if gen:
50
  self.X = np.random.uniform(0.01, 25, size=(self.dp, self.n_vars))
51
  d = {}
52
  for var in range(len(self.variable_names)):
53
  d[self.variable_names[var]] = self.X[:, var]
54
+ d["exp"] = np.exp
55
+ d["sqrt"] = np.sqrt
56
+ d["pi"] = np.pi
57
+ d["cos"] = np.cos
58
+ d["sin"] = np.sin
59
+ d["tan"] = np.tan
60
+ d["tanh"] = np.tanh
61
+ d["ln"] = np.log
62
+ d["log"] = np.log # Quite sure the Feynman dataset has no base 10 logs
63
+ d["arcsin"] = np.arcsin
64
+ self.y = eval(self.form, d)
65
  return
66
 
67
  def __str__(self):
 
84
  for i, row in enumerate(reader):
85
  if ind > first:
86
  break
87
+ if row["Filename"] == "":
88
+ continue
89
  try:
90
  p = FeynmanProblem(row, gen=gen, dp=dp)
91
  ret.append(p)
 
101
  Takes in a problem and returns a tuple: (equations, best predicted equation, actual equation)
102
  """
103
  from time import time
104
+
105
  starting = time()
106
+ equations = pysr(
107
+ problem.X,
108
+ problem.y,
109
+ variable_names=problem.variable_names,
110
+ verbosity=verbosity,
111
+ )
112
+ timing = time() - starting
113
  others = {"time": timing, "problem": problem}
114
  if not multiprocessing:
115
+ others["equations"] = equations
116
  return str(best(equations)), problem.form, others
117
 
118
+
119
+ def do_feynman_experiments_parallel(
120
+ first=100,
121
+ verbosity=0,
122
+ dp=500,
123
+ output_file_path="FeynmanExperiment.csv",
124
+ data_dir=FEYNMAN_DATASET,
125
+ ):
126
  import multiprocessing as mp
127
  from tqdm import tqdm
128
+
129
+ problems = FeynmanProblem.mk_problems(
130
+ first=first, gen=True, dp=dp, data_dir=data_dir
131
+ )
132
  ids = []
133
  predictions = []
134
  true_equations = []
 
141
  pbar.update()
142
  for res in results:
143
  prediction, true_equation, others = res
144
+ problem = others["problem"]
145
  ids.append(problem.eq_id)
146
  predictions.append(prediction)
147
  true_equations.append(true_equation)
148
+ time_takens.append(others["time"])
149
+ with open(output_file_path, "a") as f:
150
+ writer = csv.writer(f, delimiter=",")
151
+ writer.writerow(["ID", "Predicted", "True", "Time"])
152
  for i in range(len(ids)):
153
  writer.writerow([ids[i], predictions[i], true_equations[i], time_takens[i]])
154
  return
155
 
156
+
157
+ def do_feynman_experiments(
158
+ first=100,
159
+ verbosity=0,
160
+ dp=500,
161
+ output_file_path="FeynmanExperiment.csv",
162
+ data_dir=FEYNMAN_DATASET,
163
+ ):
164
  from tqdm import tqdm
165
 
166
+ problems = FeynmanProblem.mk_problems(
167
+ first=first, gen=True, dp=dp, data_dir=data_dir
168
+ )
169
  indx = range(len(problems))
170
  ids = []
171
  predictions = []
 
176
  ids.append(problem.eq_id)
177
  predictions.append(prediction)
178
  true_equations.append(true_equation)
179
+ time_takens.append(others["time"])
180
+ with open(output_file_path, "a") as f:
181
+ writer = csv.writer(f, delimiter=",")
182
+ writer.writerow(["ID", "Predicted", "True", "Time"])
183
  for i in range(len(ids)):
184
  writer.writerow([ids[i], predictions[i], true_equations[i], time_takens[i]])
185
  return
pysr/sr.py CHANGED
@@ -15,7 +15,7 @@ from datetime import datetime
15
  import warnings
16
 
17
  global_state = dict(
18
- equation_file='hall_of_fame.csv',
19
  n_features=None,
20
  variable_names=[],
21
  extra_sympy_mappings={},
@@ -25,108 +25,112 @@ global_state = dict(
25
  output_torch_format=False,
26
  multioutput=False,
27
  nout=1,
28
- selection=None
29
  )
30
 
31
  sympy_mappings = {
32
- 'div': lambda x, y : x/y,
33
- 'mult': lambda x, y : x*y,
34
- 'sqrt_abs':lambda x : sympy.sqrt(abs(x)),
35
- 'square':lambda x : x**2,
36
- 'cube': lambda x : x**3,
37
- 'plus': lambda x, y : x + y,
38
- 'sub': lambda x, y : x - y,
39
- 'neg': lambda x : -x,
40
- 'pow': lambda x, y : abs(x)**y,
41
- 'cos': lambda x : sympy.cos(x),
42
- 'sin': lambda x : sympy.sin(x),
43
- 'tan': lambda x : sympy.tan(x),
44
- 'cosh': lambda x : sympy.cosh(x),
45
- 'sinh': lambda x : sympy.sinh(x),
46
- 'tanh': lambda x : sympy.tanh(x),
47
- 'exp': lambda x : sympy.exp(x),
48
- 'acos': lambda x : sympy.acos(x),
49
- 'asin': lambda x : sympy.asin(x),
50
- 'atan': lambda x : sympy.atan(x),
51
- 'acosh':lambda x : sympy.acosh(abs(x) + 1),
52
- 'acosh_abs':lambda x : sympy.acosh(abs(x) + 1),
53
- 'asinh':lambda x : sympy.asinh(x),
54
- 'atanh':lambda x : sympy.atanh(sympy.Mod(x+1, 2)-1),
55
- 'atanh_clip':lambda x : sympy.atanh(sympy.Mod(x+1, 2)-1),
56
- 'abs': lambda x : abs(x),
57
- 'mod': lambda x, y : sympy.Mod(x, y),
58
- 'erf': lambda x : sympy.erf(x),
59
- 'erfc': lambda x : sympy.erfc(x),
60
- 'log_abs': lambda x : sympy.log(abs(x)),
61
- 'log10_abs':lambda x : sympy.log(abs(x), 10),
62
- 'log2_abs': lambda x : sympy.log(abs(x), 2),
63
- 'log1p_abs': lambda x : sympy.log(abs(x) + 1),
64
- 'floor': lambda x : sympy.floor(x),
65
- 'ceil': lambda x : sympy.ceil(x),
66
- 'sign': lambda x : sympy.sign(x),
67
- 'gamma': lambda x : sympy.gamma(x),
68
  }
69
 
70
- def pysr(X, y, weights=None,
71
- binary_operators=None,
72
- unary_operators=None,
73
- procs=4,
74
- loss='L2DistLoss()',
75
- populations=20,
76
- niterations=100,
77
- ncyclesperiteration=300,
78
- alpha=0.1,
79
- annealing=False,
80
- fractionReplaced=0.10,
81
- fractionReplacedHof=0.10,
82
- npop=1000,
83
- parsimony=1e-4,
84
- migration=True,
85
- hofMigration=True,
86
- shouldOptimizeConstants=True,
87
- topn=10,
88
- weightAddNode=1,
89
- weightInsertNode=3,
90
- weightDeleteNode=3,
91
- weightDoNothing=1,
92
- weightMutateConstant=10,
93
- weightMutateOperator=1,
94
- weightRandomize=1,
95
- weightSimplify=0.01,
96
- perturbationFactor=1.0,
97
- timeout=None,
98
- extra_sympy_mappings=None,
99
- extra_torch_mappings=None,
100
- extra_jax_mappings=None,
101
- equation_file=None,
102
- verbosity=1e9,
103
- progress=None,
104
- maxsize=20,
105
- fast_cycle=False,
106
- maxdepth=None,
107
- variable_names=None,
108
- batching=False,
109
- batchSize=50,
110
- select_k_features=None,
111
- warmupMaxsizeBy=0.0,
112
- constraints=None,
113
- useFrequency=True,
114
- tempdir=None,
115
- delete_tempfiles=True,
116
- julia_optimization=3,
117
- julia_project=None,
118
- user_input=True,
119
- update=True,
120
- temp_equation_file=False,
121
- output_jax_format=False,
122
- output_torch_format=False,
123
- optimizer_algorithm="BFGS",
124
- optimizer_nrestarts=3,
125
- optimize_probability=1.0,
126
- optimizer_iterations=10,
127
- tournament_selection_n=10,
128
- tournament_selection_p=1.0
129
- ):
 
 
 
 
130
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
131
  Note: most default parameters have been tuned over several example
132
  equations, but you should adjust `niterations`,
@@ -244,7 +248,7 @@ def pysr(X, y, weights=None,
244
  :type: pd.DataFrame/list
245
  """
246
  if binary_operators is None:
247
- binary_operators = '+ * - /'.split(' ')
248
  if unary_operators is None:
249
  unary_operators = []
250
  if extra_sympy_mappings is None:
@@ -255,16 +259,18 @@ def pysr(X, y, weights=None,
255
  constraints = {}
256
 
257
  if progress is not None:
258
- if progress and ('buffer' not in sys.stdout.__dir__()):
259
- warnings.warn("Note: it looks like you are running in Jupyter. The progress bar will be turned off.")
 
 
260
  progress = False
261
  else:
262
- if 'buffer' in sys.stdout.__dir__():
263
  progress = True
264
  else:
265
  progress = False
266
 
267
- assert optimizer_algorithm in ['NelderMead', 'BFGS']
268
  assert tournament_selection_n < npop
269
 
270
  if isinstance(X, pd.DataFrame):
@@ -275,25 +281,34 @@ def pysr(X, y, weights=None,
275
  X = X[:, None]
276
 
277
  if len(variable_names) == 0:
278
- variable_names = [f'x{i}' for i in range(X.shape[1])]
279
-
280
- use_custom_variable_names = (len(variable_names) != 0)
281
-
282
- _check_assertions(X, binary_operators, unary_operators,
283
- use_custom_variable_names, variable_names, weights, y)
 
 
 
 
 
 
 
284
  _check_for_julia_installation()
285
 
286
-
287
  if len(X) > 10000 and not batching:
288
- warnings.warn("Note: you are running with more than 10,000 datapoints. You should consider turning on batching (https://pysr.readthedocs.io/en/latest/docs/options/#batching). You should also reconsider if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form with symbolic regression. More datapoints will lower the search speed.")
 
 
289
 
290
  if maxsize > 40:
291
- warnings.warn("Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`.")
 
 
292
 
293
  X, variable_names, selection = _handle_feature_selection(
294
- X, select_k_features,
295
- use_custom_variable_names, variable_names, y
296
- )
297
 
298
  if maxdepth is None:
299
  maxdepth = maxsize
@@ -312,81 +327,102 @@ def pysr(X, y, weights=None,
312
  else:
313
  raise NotImplementedError("y shape not supported!")
314
 
315
- kwargs = dict(X=X, y=y, weights=weights,
316
- alpha=alpha, annealing=annealing, batchSize=batchSize,
317
- batching=batching, binary_operators=binary_operators,
318
- fast_cycle=fast_cycle,
319
- fractionReplaced=fractionReplaced,
320
- ncyclesperiteration=ncyclesperiteration,
321
- niterations=niterations, npop=npop, topn=topn,
322
- verbosity=verbosity, progress=progress, update=update,
323
- julia_optimization=julia_optimization, timeout=timeout,
324
- fractionReplacedHof=fractionReplacedHof,
325
- hofMigration=hofMigration, maxdepth=maxdepth,
326
- maxsize=maxsize, migration=migration,
327
- optimizer_algorithm=optimizer_algorithm,
328
- optimizer_nrestarts=optimizer_nrestarts,
329
- optimize_probability=optimize_probability,
330
- optimizer_iterations=optimizer_iterations,
331
- parsimony=parsimony, perturbationFactor=perturbationFactor,
332
- populations=populations, procs=procs,
333
- shouldOptimizeConstants=shouldOptimizeConstants,
334
- unary_operators=unary_operators, useFrequency=useFrequency,
335
- use_custom_variable_names=use_custom_variable_names,
336
- variable_names=variable_names, warmupMaxsizeBy=warmupMaxsizeBy,
337
- weightAddNode=weightAddNode,
338
- weightDeleteNode=weightDeleteNode,
339
- weightDoNothing=weightDoNothing,
340
- weightInsertNode=weightInsertNode,
341
- weightMutateConstant=weightMutateConstant,
342
- weightMutateOperator=weightMutateOperator,
343
- weightRandomize=weightRandomize,
344
- weightSimplify=weightSimplify,
345
- constraints=constraints,
346
- extra_sympy_mappings=extra_sympy_mappings,
347
- extra_jax_mappings=extra_jax_mappings,
348
- extra_torch_mappings=extra_torch_mappings,
349
- julia_project=julia_project, loss=loss,
350
- output_jax_format=output_jax_format,
351
- output_torch_format=output_torch_format,
352
- selection=selection,
353
- multioutput=multioutput, nout=nout,
354
- tournament_selection_n=tournament_selection_n,
355
- tournament_selection_p=tournament_selection_p)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
  kwargs = {**_set_paths(tempdir), **kwargs}
358
 
359
  if temp_equation_file:
360
- equation_file = kwargs['tmpdir'] / f'hall_of_fame.csv'
361
  elif equation_file is None:
362
  date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
363
- equation_file = 'hall_of_fame_' + date_time + '.csv'
364
 
365
  kwargs = {**dict(equation_file=equation_file), **kwargs}
366
 
367
-
368
- pkg_directory = kwargs['pkg_directory']
369
  manifest_file = None
370
- if kwargs['julia_project'] is not None:
371
- manifest_filepath = Path(kwargs['julia_project']) / 'Manifest.toml'
372
  else:
373
- manifest_filepath = pkg_directory / 'Manifest.toml'
374
 
375
- kwargs['need_install'] = False
376
 
377
  if not (manifest_filepath).is_file():
378
- kwargs['need_install'] = (not user_input) or _yesno("I will install Julia packages using PySR's Project.toml file. OK?")
379
- if kwargs['need_install']:
 
 
380
  print("OK. I will install at launch.")
381
  assert update
382
 
383
- kwargs['def_hyperparams'] = _create_inline_operators(**kwargs)
384
 
385
  _handle_constraints(**kwargs)
386
 
387
- kwargs['constraints_str'] = _make_constraints_str(**kwargs)
388
- kwargs['def_hyperparams'] = _make_hyperparams_julia_str(**kwargs)
389
- kwargs['def_datasets'] = _make_datasets_julia_str(**kwargs)
390
 
391
  _create_julia_files(**kwargs)
392
  _final_pysr_process(**kwargs)
@@ -395,7 +431,7 @@ def pysr(X, y, weights=None,
395
  equations = get_hof(**kwargs)
396
 
397
  if delete_tempfiles:
398
- shutil.rmtree(kwargs['tmpdir'])
399
 
400
  return equations
401
 
@@ -403,7 +439,7 @@ def pysr(X, y, weights=None,
403
  def _set_globals(X, **kwargs):
404
  global global_state
405
 
406
- global_state['n_features'] = X.shape[1]
407
  for key, value in kwargs.items():
408
  if key in global_state:
409
  global_state[key] = value
@@ -411,34 +447,37 @@ def _set_globals(X, **kwargs):
411
 
412
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
413
  command = [
414
- f'julia', f'-O{julia_optimization:d}',
 
415
  str(runfile_filename),
416
  ]
417
  if timeout is not None:
418
- command = [f'timeout', f'{timeout}'] + command
419
  _cmd_runner(command, **kwargs)
420
 
 
421
  def _cmd_runner(command, progress, **kwargs):
422
- if kwargs['verbosity'] > 0:
423
- print("Running on", ' '.join(command))
424
  process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=-1)
425
  try:
426
  while True:
427
  line = process.stdout.readline()
428
- if not line: break
429
- decoded_line = line.decode('utf-8')
 
430
  if progress:
431
- decoded_line = (decoded_line
432
- .replace('\\033[K', '\033[K')
433
- .replace('\\033[1A', '\033[1A')
434
- .replace('\\033[1B', '\033[1B')
435
- .replace('\\r', '\r')
436
- .encode(sys.stdout.encoding, errors='replace')
437
- )
438
  sys.stdout.buffer.write(decoded_line)
439
  sys.stdout.flush()
440
  else:
441
- print(decoded_line, end='')
442
 
443
  process.stdout.close()
444
  process.wait()
@@ -446,62 +485,94 @@ def _cmd_runner(command, progress, **kwargs):
446
  print("Killing process... will return when done.")
447
  process.kill()
448
 
449
- def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, def_hyperparams,
450
- fractionReplaced, ncyclesperiteration, niterations, npop,
451
- runfile_filename, topn, verbosity, julia_project, procs, weights,
452
- X, variable_names, pkg_directory, need_install, update, **kwargs):
453
- with open(hyperparam_filename, 'w') as f:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
454
  print(def_hyperparams, file=f)
455
- with open(dataset_filename, 'w') as f:
456
  print(def_datasets, file=f)
457
- with open(runfile_filename, 'w') as f:
458
  if julia_project is None:
459
  julia_project = pkg_directory
460
  else:
461
  julia_project = Path(julia_project)
462
- print(f'import Pkg', file=f)
463
  print(f'Pkg.activate("{_escape_filename(julia_project)}")', file=f)
464
  if need_install:
465
- print(f'Pkg.instantiate()', file=f)
466
- print(f'Pkg.update()', file=f)
467
- print(f'Pkg.precompile()', file=f)
468
  elif update:
469
- print(f'Pkg.update()', file=f)
470
- print(f'using SymbolicRegression', file=f)
471
  print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
472
  print(f'include("{_escape_filename(dataset_filename)}")', file=f)
473
  if len(variable_names) == 0:
474
  varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
475
  else:
476
- varMap = "[" + ",".join(['"' + vname + '"' for vname in variable_names]) + "]"
 
 
477
 
478
  if weights is not None:
479
- print(f'EquationSearch(X, y, weights=weights, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
 
 
 
480
  else:
481
- print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
 
 
 
482
 
483
 
484
- def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
485
- multioutput, **kwargs):
 
486
  def_datasets = """using DelimitedFiles"""
487
- np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
488
  if multioutput:
489
- np.savetxt(y_filename, y.astype(np.float32), delimiter=',')
490
  else:
491
- np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
492
  if weights is not None:
493
  if multioutput:
494
- np.savetxt(weights_filename, weights.astype(np.float32), delimiter=',')
495
  else:
496
- np.savetxt(weights_filename, weights.reshape(-1, 1).astype(np.float32), delimiter=',')
 
 
 
 
497
  def_datasets += f"""
498
  X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))"""
499
 
500
  if multioutput:
501
- def_datasets+= f"""
502
  y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')))"""
503
  else:
504
- def_datasets+= f"""
505
  y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
506
 
507
  if weights is not None:
@@ -513,30 +584,69 @@ weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', Fl
513
  weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
514
  return def_datasets
515
 
516
- def _make_hyperparams_julia_str(X, alpha, annealing, batchSize, batching, binary_operators, constraints_str,
517
- def_hyperparams, equation_file, fast_cycle, fractionReplacedHof, hofMigration,
518
- maxdepth, maxsize, migration,
519
- optimizer_algorithm, optimizer_nrestarts,
520
- optimize_probability, optimizer_iterations, npop,
521
- parsimony, perturbationFactor, populations, procs, shouldOptimizeConstants,
522
- unary_operators, useFrequency, use_custom_variable_names,
523
- variable_names, warmupMaxsizeBy, weightAddNode,
524
- ncyclesperiteration, fractionReplaced, topn, verbosity, progress, loss,
525
- weightDeleteNode, weightDoNothing, weightInsertNode, weightMutateConstant,
526
- weightMutateOperator, weightRandomize, weightSimplify, weights,
527
- tournament_selection_n, tournament_selection_p,
528
- **kwargs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
529
  try:
530
  term_width = shutil.get_terminal_size().columns
531
  except:
532
- _, term_width = subprocess.check_output(['stty', 'size']).split()
 
533
  def tuple_fix(ops):
534
  if len(ops) > 1:
535
- return ', '.join(ops)
536
  elif len(ops) == 0:
537
- return ''
538
  else:
539
- return ops[0] + ','
540
 
541
  def_hyperparams += f"""\n
542
  plus=(+)
@@ -606,7 +716,7 @@ progress={'true' if progress else 'false'},
606
  terminal_width={term_width:d}
607
  """
608
 
609
- def_hyperparams += '\n)'
610
  return def_hyperparams
611
 
612
 
@@ -639,16 +749,20 @@ def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs
639
  for op in binary_operators:
640
  if op not in constraints:
641
  constraints[op] = (-1, -1)
642
- if op in ['plus', 'sub']:
643
  if constraints[op][0] != constraints[op][1]:
644
  raise NotImplementedError(
645
- "You need equal constraints on both sides for - and *, due to simplification strategies.")
646
- elif op == 'mult':
 
647
  # Make sure the complex expression is in the left side.
648
  if constraints[op][0] == -1:
649
  continue
650
  elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
651
- constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
 
 
 
652
 
653
 
654
  def _create_inline_operators(binary_operators, unary_operators, **kwargs):
@@ -656,27 +770,33 @@ def _create_inline_operators(binary_operators, unary_operators, **kwargs):
656
  for op_list in [binary_operators, unary_operators]:
657
  for i in range(len(op_list)):
658
  op = op_list[i]
659
- is_user_defined_operator = '(' in op
660
 
661
  if is_user_defined_operator:
662
  def_hyperparams += op + "\n"
663
  # Cut off from the first non-alphanumeric char:
664
  first_non_char = [
665
- j for j in range(len(op))
666
- if not (op[j].isalpha() or op[j].isdigit())][0]
 
 
667
  function_name = op[:first_non_char]
668
  op_list[i] = function_name
669
  return def_hyperparams
670
 
671
 
672
- def _handle_feature_selection(X, select_k_features, use_custom_variable_names, variable_names, y):
 
 
673
  if select_k_features is not None:
674
  selection = run_feature_selection(X, y, select_k_features)
675
  print(f"Using features {selection}")
676
  X = X[:, selection]
677
 
678
  if use_custom_variable_names:
679
- variable_names = [variable_names[selection[i]] for i in range(len(selection))]
 
 
680
  else:
681
  selection = None
682
  return X, variable_names, selection
@@ -687,22 +807,34 @@ def _set_paths(tempdir):
687
  pkg_directory = Path(__file__).parents[1]
688
  default_project_file = pkg_directory / "Project.toml"
689
  tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
690
- hyperparam_filename = tmpdir / f'hyperparams.jl'
691
- dataset_filename = tmpdir / f'dataset.jl'
692
- runfile_filename = tmpdir / f'runfile.jl'
693
  X_filename = tmpdir / "X.csv"
694
  y_filename = tmpdir / "y.csv"
695
  weights_filename = tmpdir / "weights.csv"
696
- return dict(pkg_directory=pkg_directory,
697
- default_project_file=default_project_file,
698
- X_filename=X_filename,
699
- dataset_filename=dataset_filename,
700
- hyperparam_filename=hyperparam_filename,
701
- runfile_filename=runfile_filename, tmpdir=tmpdir,
702
- weights_filename=weights_filename, y_filename=y_filename)
703
-
704
-
705
- def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y):
 
 
 
 
 
 
 
 
 
 
 
 
706
  # Check for potential errors before they happen
707
  assert len(unary_operators) + len(binary_operators) > 0
708
  assert len(X.shape) == 2
@@ -714,76 +846,108 @@ def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_
714
  if use_custom_variable_names:
715
  assert len(variable_names) == X.shape[1]
716
 
 
717
  def _check_for_julia_installation():
718
  try:
719
  process = subprocess.Popen(["julia", "-v"], stdout=subprocess.PIPE, bufsize=-1)
720
  while True:
721
  line = process.stdout.readline()
722
- if not line: break
 
723
  process.stdout.close()
724
  process.wait()
725
  except FileNotFoundError:
726
  import os
727
- raise RuntimeError(f"Your current $PATH is: {os.environ['PATH']}\nPySR could not start julia. Make sure julia is installed and on your $PATH.")
 
 
 
728
  process.kill()
729
 
730
 
731
  def run_feature_selection(X, y, select_k_features):
732
  """Use a gradient boosting tree regressor as a proxy for finding
733
- the k most important features in X, returning indices for those
734
- features as output."""
735
 
736
  from sklearn.ensemble import RandomForestRegressor
737
  from sklearn.feature_selection import SelectFromModel, SelectKBest
738
 
739
  clf = RandomForestRegressor(n_estimators=100, max_depth=3, random_state=0)
740
  clf.fit(X, y)
741
- selector = SelectFromModel(clf, threshold=-np.inf,
742
- max_features=select_k_features, prefit=True)
 
743
  return selector.get_support(indices=True)
744
 
745
- def get_hof(equation_file=None, n_features=None, variable_names=None,
746
- output_jax_format=None, output_torch_format=None,
747
- selection=None, extra_sympy_mappings=None,
748
- extra_jax_mappings=None, extra_torch_mappings=None,
749
- multioutput=None, nout=None, **kwargs):
 
 
 
 
 
 
 
 
 
 
750
  """Get the equations from a hall of fame file. If no arguments
751
  entered, the ones used previously from a call to PySR will be used."""
752
 
753
  global global_state
754
 
755
- if equation_file is None: equation_file = global_state['equation_file']
756
- if n_features is None: n_features = global_state['n_features']
757
- if variable_names is None: variable_names = global_state['variable_names']
758
- if extra_sympy_mappings is None: extra_sympy_mappings = global_state['extra_sympy_mappings']
759
- if extra_jax_mappings is None: extra_jax_mappings = global_state['extra_jax_mappings']
760
- if extra_torch_mappings is None: extra_torch_mappings = global_state['extra_torch_mappings']
761
- if output_torch_format is None: output_torch_format = global_state['output_torch_format']
762
- if output_jax_format is None: output_jax_format = global_state['output_jax_format']
763
- if multioutput is None: multioutput = global_state['multioutput']
764
- if nout is None: nout = global_state['nout']
765
- if selection is None: selection = global_state['selection']
766
-
767
- global_state['selection'] = selection
768
- global_state['equation_file'] = equation_file
769
- global_state['n_features'] = n_features
770
- global_state['variable_names'] = variable_names
771
- global_state['extra_sympy_mappings'] = extra_sympy_mappings
772
- global_state['extra_jax_mappings'] = extra_jax_mappings
773
- global_state['extra_torch_mappings'] = extra_torch_mappings
774
- global_state['output_torch_format'] = output_torch_format
775
- global_state['output_jax_format'] = output_jax_format
776
- global_state['multioutput'] = multioutput
777
- global_state['nout'] = nout
778
- global_state['selection'] = selection
 
 
 
 
 
 
 
 
 
 
 
779
 
780
  try:
781
  if multioutput:
782
- all_outputs = [pd.read_csv(str(equation_file) + f'.out{i}' + '.bkup', sep="|") for i in range(1, nout+1)]
 
 
 
783
  else:
784
- all_outputs = [pd.read_csv(str(equation_file) + '.bkup', sep="|")]
785
  except FileNotFoundError:
786
- raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
 
 
787
 
788
  ret_outputs = []
789
 
@@ -798,19 +962,16 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
798
  jax_format = []
799
  if output_torch_format:
800
  torch_format = []
801
- use_custom_variable_names = (len(variable_names) != 0)
802
- local_sympy_mappings = {
803
- **extra_sympy_mappings,
804
- **sympy_mappings
805
- }
806
 
807
  if use_custom_variable_names:
808
  sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
809
  else:
810
- sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
811
 
812
  for i in range(len(output)):
813
- eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
814
  sympy_format.append(eqn)
815
 
816
  # Numpy:
@@ -819,37 +980,46 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
819
  # JAX:
820
  if output_jax_format:
821
  from .export_jax import sympy2jax
 
822
  func, params = sympy2jax(eqn, sympy_symbols, selection)
823
- jax_format.append({'callable': func, 'parameters': params})
824
 
825
  # Torch:
826
  if output_torch_format:
827
  from .export_torch import sympy2torch
 
828
  module = sympy2torch(eqn, sympy_symbols, selection=selection)
829
  torch_format.append(module)
830
 
831
- curMSE = output.loc[i, 'MSE']
832
- curComplexity = output.loc[i, 'Complexity']
833
 
834
  if lastMSE is None:
835
  cur_score = 0.0
836
  else:
837
- cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
838
 
839
  scores.append(cur_score)
840
  lastMSE = curMSE
841
  lastComplexity = curComplexity
842
 
843
- output['score'] = np.array(scores)
844
- output['sympy_format'] = sympy_format
845
- output['lambda_format'] = lambda_format
846
- output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
 
 
 
 
 
 
 
847
  if output_jax_format:
848
- output_cols += ['jax_format']
849
- output['jax_format'] = jax_format
850
  if output_torch_format:
851
- output_cols += ['torch_format']
852
- output['torch_format'] = torch_format
853
 
854
  ret_outputs.append(output[output_cols])
855
 
@@ -858,67 +1028,80 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
858
  else:
859
  return ret_outputs[0]
860
 
 
861
  def best_row(equations=None):
862
  """Return the best row of a hall of fame file using the score column.
863
  By default this uses the last equation file.
864
  """
865
- if equations is None: equations = get_hof()
 
866
  if isinstance(equations, list):
867
- return [eq.iloc[np.argmax(eq['score'])] for eq in equations]
868
  else:
869
- return equations.iloc[np.argmax(equations['score'])]
 
870
 
871
  def best_tex(equations=None):
872
  """Return the equation with the best score, in latex format
873
  By default this uses the last equation file.
874
  """
875
- if equations is None: equations = get_hof()
 
876
  if isinstance(equations, list):
877
- return [sympy.latex(best_row(eq)['sympy_format'].simplify()) for eq in equations]
 
 
878
  else:
879
- return sympy.latex(best_row(equations)['sympy_format'].simplify())
 
880
 
881
  def best(equations=None):
882
  """Return the equation with the best score, in sympy format.
883
  By default this uses the last equation file.
884
  """
885
- if equations is None: equations = get_hof()
 
886
  if isinstance(equations, list):
887
- return [best_row(eq)['sympy_format'].simplify() for eq in equations]
888
  else:
889
- return best_row(equations)['sympy_format'].simplify()
 
890
 
891
  def best_callable(equations=None):
892
  """Return the equation with the best score, in callable format.
893
  By default this uses the last equation file.
894
  """
895
- if equations is None: equations = get_hof()
 
896
  if isinstance(equations, list):
897
- return [best_row(eq)['lambda_format'] for eq in equations]
898
  else:
899
- return best_row(equations)['lambda_format']
 
900
 
901
  def _escape_filename(filename):
902
  """Turns a file into a string representation with correctly escaped backslashes"""
903
  repr = str(filename)
904
- repr = repr.replace('\\', '\\\\')
905
  return repr
906
 
 
907
  # https://gist.github.com/garrettdreyfus/8153571
908
  def _yesno(question):
909
  """Simple Yes/No Function."""
910
- prompt = f'{question} (y/n): '
911
  ans = input(prompt).strip().lower()
912
- if ans not in ['y', 'n']:
913
- print(f'{ans} is invalid, please try again...')
914
  return _yesno(question)
915
- if ans == 'y':
916
  return True
917
  return False
918
 
919
 
920
  class CallableEquation(object):
921
  """Simple wrapper for numpy lambda functions built with sympy"""
 
922
  def __init__(self, sympy_symbols, eqn, selection=None):
923
  self._sympy = eqn
924
  self._sympy_symbols = sympy_symbols
@@ -933,4 +1116,3 @@ class CallableEquation(object):
933
  return self._lambda(*X[:, self._selection].T)
934
  else:
935
  return self._lambda(*X.T)
936
-
 
15
  import warnings
16
 
17
  global_state = dict(
18
+ equation_file="hall_of_fame.csv",
19
  n_features=None,
20
  variable_names=[],
21
  extra_sympy_mappings={},
 
25
  output_torch_format=False,
26
  multioutput=False,
27
  nout=1,
28
+ selection=None,
29
  )
30
 
31
  sympy_mappings = {
32
+ "div": lambda x, y: x / y,
33
+ "mult": lambda x, y: x * y,
34
+ "sqrt_abs": lambda x: sympy.sqrt(abs(x)),
35
+ "square": lambda x: x ** 2,
36
+ "cube": lambda x: x ** 3,
37
+ "plus": lambda x, y: x + y,
38
+ "sub": lambda x, y: x - y,
39
+ "neg": lambda x: -x,
40
+ "pow": lambda x, y: abs(x) ** y,
41
+ "cos": lambda x: sympy.cos(x),
42
+ "sin": lambda x: sympy.sin(x),
43
+ "tan": lambda x: sympy.tan(x),
44
+ "cosh": lambda x: sympy.cosh(x),
45
+ "sinh": lambda x: sympy.sinh(x),
46
+ "tanh": lambda x: sympy.tanh(x),
47
+ "exp": lambda x: sympy.exp(x),
48
+ "acos": lambda x: sympy.acos(x),
49
+ "asin": lambda x: sympy.asin(x),
50
+ "atan": lambda x: sympy.atan(x),
51
+ "acosh": lambda x: sympy.acosh(abs(x) + 1),
52
+ "acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
53
+ "asinh": lambda x: sympy.asinh(x),
54
+ "atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
55
+ "atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - 1),
56
+ "abs": lambda x: abs(x),
57
+ "mod": lambda x, y: sympy.Mod(x, y),
58
+ "erf": lambda x: sympy.erf(x),
59
+ "erfc": lambda x: sympy.erfc(x),
60
+ "log_abs": lambda x: sympy.log(abs(x)),
61
+ "log10_abs": lambda x: sympy.log(abs(x), 10),
62
+ "log2_abs": lambda x: sympy.log(abs(x), 2),
63
+ "log1p_abs": lambda x: sympy.log(abs(x) + 1),
64
+ "floor": lambda x: sympy.floor(x),
65
+ "ceil": lambda x: sympy.ceil(x),
66
+ "sign": lambda x: sympy.sign(x),
67
+ "gamma": lambda x: sympy.gamma(x),
68
  }
69
 
70
+
71
+ def pysr(
72
+ X,
73
+ y,
74
+ weights=None,
75
+ binary_operators=None,
76
+ unary_operators=None,
77
+ procs=4,
78
+ loss="L2DistLoss()",
79
+ populations=20,
80
+ niterations=100,
81
+ ncyclesperiteration=300,
82
+ alpha=0.1,
83
+ annealing=False,
84
+ fractionReplaced=0.10,
85
+ fractionReplacedHof=0.10,
86
+ npop=1000,
87
+ parsimony=1e-4,
88
+ migration=True,
89
+ hofMigration=True,
90
+ shouldOptimizeConstants=True,
91
+ topn=10,
92
+ weightAddNode=1,
93
+ weightInsertNode=3,
94
+ weightDeleteNode=3,
95
+ weightDoNothing=1,
96
+ weightMutateConstant=10,
97
+ weightMutateOperator=1,
98
+ weightRandomize=1,
99
+ weightSimplify=0.01,
100
+ perturbationFactor=1.0,
101
+ timeout=None,
102
+ extra_sympy_mappings=None,
103
+ extra_torch_mappings=None,
104
+ extra_jax_mappings=None,
105
+ equation_file=None,
106
+ verbosity=1e9,
107
+ progress=None,
108
+ maxsize=20,
109
+ fast_cycle=False,
110
+ maxdepth=None,
111
+ variable_names=None,
112
+ batching=False,
113
+ batchSize=50,
114
+ select_k_features=None,
115
+ warmupMaxsizeBy=0.0,
116
+ constraints=None,
117
+ useFrequency=True,
118
+ tempdir=None,
119
+ delete_tempfiles=True,
120
+ julia_optimization=3,
121
+ julia_project=None,
122
+ user_input=True,
123
+ update=True,
124
+ temp_equation_file=False,
125
+ output_jax_format=False,
126
+ output_torch_format=False,
127
+ optimizer_algorithm="BFGS",
128
+ optimizer_nrestarts=3,
129
+ optimize_probability=1.0,
130
+ optimizer_iterations=10,
131
+ tournament_selection_n=10,
132
+ tournament_selection_p=1.0,
133
+ ):
134
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
135
  Note: most default parameters have been tuned over several example
136
  equations, but you should adjust `niterations`,
 
248
  :type: pd.DataFrame/list
249
  """
250
  if binary_operators is None:
251
+ binary_operators = "+ * - /".split(" ")
252
  if unary_operators is None:
253
  unary_operators = []
254
  if extra_sympy_mappings is None:
 
259
  constraints = {}
260
 
261
  if progress is not None:
262
+ if progress and ("buffer" not in sys.stdout.__dir__()):
263
+ warnings.warn(
264
+ "Note: it looks like you are running in Jupyter. The progress bar will be turned off."
265
+ )
266
  progress = False
267
  else:
268
+ if "buffer" in sys.stdout.__dir__():
269
  progress = True
270
  else:
271
  progress = False
272
 
273
+ assert optimizer_algorithm in ["NelderMead", "BFGS"]
274
  assert tournament_selection_n < npop
275
 
276
  if isinstance(X, pd.DataFrame):
 
281
  X = X[:, None]
282
 
283
  if len(variable_names) == 0:
284
+ variable_names = [f"x{i}" for i in range(X.shape[1])]
285
+
286
+ use_custom_variable_names = len(variable_names) != 0
287
+
288
+ _check_assertions(
289
+ X,
290
+ binary_operators,
291
+ unary_operators,
292
+ use_custom_variable_names,
293
+ variable_names,
294
+ weights,
295
+ y,
296
+ )
297
  _check_for_julia_installation()
298
 
 
299
  if len(X) > 10000 and not batching:
300
+ warnings.warn(
301
+ "Note: you are running with more than 10,000 datapoints. You should consider turning on batching (https://pysr.readthedocs.io/en/latest/docs/options/#batching). You should also reconsider if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form with symbolic regression. More datapoints will lower the search speed."
302
+ )
303
 
304
  if maxsize > 40:
305
+ warnings.warn(
306
+ "Note: Using a large maxsize for the equation search will be exponentially slower and use significant memory. You should consider turning `useFrequency` to False, and perhaps use `warmupMaxsizeBy`."
307
+ )
308
 
309
  X, variable_names, selection = _handle_feature_selection(
310
+ X, select_k_features, use_custom_variable_names, variable_names, y
311
+ )
 
312
 
313
  if maxdepth is None:
314
  maxdepth = maxsize
 
327
  else:
328
  raise NotImplementedError("y shape not supported!")
329
 
330
+ kwargs = dict(
331
+ X=X,
332
+ y=y,
333
+ weights=weights,
334
+ alpha=alpha,
335
+ annealing=annealing,
336
+ batchSize=batchSize,
337
+ batching=batching,
338
+ binary_operators=binary_operators,
339
+ fast_cycle=fast_cycle,
340
+ fractionReplaced=fractionReplaced,
341
+ ncyclesperiteration=ncyclesperiteration,
342
+ niterations=niterations,
343
+ npop=npop,
344
+ topn=topn,
345
+ verbosity=verbosity,
346
+ progress=progress,
347
+ update=update,
348
+ julia_optimization=julia_optimization,
349
+ timeout=timeout,
350
+ fractionReplacedHof=fractionReplacedHof,
351
+ hofMigration=hofMigration,
352
+ maxdepth=maxdepth,
353
+ maxsize=maxsize,
354
+ migration=migration,
355
+ optimizer_algorithm=optimizer_algorithm,
356
+ optimizer_nrestarts=optimizer_nrestarts,
357
+ optimize_probability=optimize_probability,
358
+ optimizer_iterations=optimizer_iterations,
359
+ parsimony=parsimony,
360
+ perturbationFactor=perturbationFactor,
361
+ populations=populations,
362
+ procs=procs,
363
+ shouldOptimizeConstants=shouldOptimizeConstants,
364
+ unary_operators=unary_operators,
365
+ useFrequency=useFrequency,
366
+ use_custom_variable_names=use_custom_variable_names,
367
+ variable_names=variable_names,
368
+ warmupMaxsizeBy=warmupMaxsizeBy,
369
+ weightAddNode=weightAddNode,
370
+ weightDeleteNode=weightDeleteNode,
371
+ weightDoNothing=weightDoNothing,
372
+ weightInsertNode=weightInsertNode,
373
+ weightMutateConstant=weightMutateConstant,
374
+ weightMutateOperator=weightMutateOperator,
375
+ weightRandomize=weightRandomize,
376
+ weightSimplify=weightSimplify,
377
+ constraints=constraints,
378
+ extra_sympy_mappings=extra_sympy_mappings,
379
+ extra_jax_mappings=extra_jax_mappings,
380
+ extra_torch_mappings=extra_torch_mappings,
381
+ julia_project=julia_project,
382
+ loss=loss,
383
+ output_jax_format=output_jax_format,
384
+ output_torch_format=output_torch_format,
385
+ selection=selection,
386
+ multioutput=multioutput,
387
+ nout=nout,
388
+ tournament_selection_n=tournament_selection_n,
389
+ tournament_selection_p=tournament_selection_p,
390
+ )
391
 
392
  kwargs = {**_set_paths(tempdir), **kwargs}
393
 
394
  if temp_equation_file:
395
+ equation_file = kwargs["tmpdir"] / f"hall_of_fame.csv"
396
  elif equation_file is None:
397
  date_time = datetime.now().strftime("%Y-%m-%d_%H%M%S.%f")[:-3]
398
+ equation_file = "hall_of_fame_" + date_time + ".csv"
399
 
400
  kwargs = {**dict(equation_file=equation_file), **kwargs}
401
 
402
+ pkg_directory = kwargs["pkg_directory"]
 
403
  manifest_file = None
404
+ if kwargs["julia_project"] is not None:
405
+ manifest_filepath = Path(kwargs["julia_project"]) / "Manifest.toml"
406
  else:
407
+ manifest_filepath = pkg_directory / "Manifest.toml"
408
 
409
+ kwargs["need_install"] = False
410
 
411
  if not (manifest_filepath).is_file():
412
+ kwargs["need_install"] = (not user_input) or _yesno(
413
+ "I will install Julia packages using PySR's Project.toml file. OK?"
414
+ )
415
+ if kwargs["need_install"]:
416
  print("OK. I will install at launch.")
417
  assert update
418
 
419
+ kwargs["def_hyperparams"] = _create_inline_operators(**kwargs)
420
 
421
  _handle_constraints(**kwargs)
422
 
423
+ kwargs["constraints_str"] = _make_constraints_str(**kwargs)
424
+ kwargs["def_hyperparams"] = _make_hyperparams_julia_str(**kwargs)
425
+ kwargs["def_datasets"] = _make_datasets_julia_str(**kwargs)
426
 
427
  _create_julia_files(**kwargs)
428
  _final_pysr_process(**kwargs)
 
431
  equations = get_hof(**kwargs)
432
 
433
  if delete_tempfiles:
434
+ shutil.rmtree(kwargs["tmpdir"])
435
 
436
  return equations
437
 
 
439
  def _set_globals(X, **kwargs):
440
  global global_state
441
 
442
+ global_state["n_features"] = X.shape[1]
443
  for key, value in kwargs.items():
444
  if key in global_state:
445
  global_state[key] = value
 
447
 
448
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
449
  command = [
450
+ f"julia",
451
+ f"-O{julia_optimization:d}",
452
  str(runfile_filename),
453
  ]
454
  if timeout is not None:
455
+ command = [f"timeout", f"{timeout}"] + command
456
  _cmd_runner(command, **kwargs)
457
 
458
+
459
  def _cmd_runner(command, progress, **kwargs):
460
+ if kwargs["verbosity"] > 0:
461
+ print("Running on", " ".join(command))
462
  process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=-1)
463
  try:
464
  while True:
465
  line = process.stdout.readline()
466
+ if not line:
467
+ break
468
+ decoded_line = line.decode("utf-8")
469
  if progress:
470
+ decoded_line = (
471
+ decoded_line.replace("\\033[K", "\033[K")
472
+ .replace("\\033[1A", "\033[1A")
473
+ .replace("\\033[1B", "\033[1B")
474
+ .replace("\\r", "\r")
475
+ .encode(sys.stdout.encoding, errors="replace")
476
+ )
477
  sys.stdout.buffer.write(decoded_line)
478
  sys.stdout.flush()
479
  else:
480
+ print(decoded_line, end="")
481
 
482
  process.stdout.close()
483
  process.wait()
 
485
  print("Killing process... will return when done.")
486
  process.kill()
487
 
488
+
489
+ def _create_julia_files(
490
+ dataset_filename,
491
+ def_datasets,
492
+ hyperparam_filename,
493
+ def_hyperparams,
494
+ fractionReplaced,
495
+ ncyclesperiteration,
496
+ niterations,
497
+ npop,
498
+ runfile_filename,
499
+ topn,
500
+ verbosity,
501
+ julia_project,
502
+ procs,
503
+ weights,
504
+ X,
505
+ variable_names,
506
+ pkg_directory,
507
+ need_install,
508
+ update,
509
+ **kwargs,
510
+ ):
511
+ with open(hyperparam_filename, "w") as f:
512
  print(def_hyperparams, file=f)
513
+ with open(dataset_filename, "w") as f:
514
  print(def_datasets, file=f)
515
+ with open(runfile_filename, "w") as f:
516
  if julia_project is None:
517
  julia_project = pkg_directory
518
  else:
519
  julia_project = Path(julia_project)
520
+ print(f"import Pkg", file=f)
521
  print(f'Pkg.activate("{_escape_filename(julia_project)}")', file=f)
522
  if need_install:
523
+ print(f"Pkg.instantiate()", file=f)
524
+ print(f"Pkg.update()", file=f)
525
+ print(f"Pkg.precompile()", file=f)
526
  elif update:
527
+ print(f"Pkg.update()", file=f)
528
+ print(f"using SymbolicRegression", file=f)
529
  print(f'include("{_escape_filename(hyperparam_filename)}")', file=f)
530
  print(f'include("{_escape_filename(dataset_filename)}")', file=f)
531
  if len(variable_names) == 0:
532
  varMap = "[" + ",".join([f'"x{i}"' for i in range(X.shape[1])]) + "]"
533
  else:
534
+ varMap = (
535
+ "[" + ",".join(['"' + vname + '"' for vname in variable_names]) + "]"
536
+ )
537
 
538
  if weights is not None:
539
+ print(
540
+ f"EquationSearch(X, y, weights=weights, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})",
541
+ file=f,
542
+ )
543
  else:
544
+ print(
545
+ f"EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})",
546
+ file=f,
547
+ )
548
 
549
 
550
+ def _make_datasets_julia_str(
551
+ X, X_filename, weights, weights_filename, y, y_filename, multioutput, **kwargs
552
+ ):
553
  def_datasets = """using DelimitedFiles"""
554
+ np.savetxt(X_filename, X.astype(np.float32), delimiter=",")
555
  if multioutput:
556
+ np.savetxt(y_filename, y.astype(np.float32), delimiter=",")
557
  else:
558
+ np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=",")
559
  if weights is not None:
560
  if multioutput:
561
+ np.savetxt(weights_filename, weights.astype(np.float32), delimiter=",")
562
  else:
563
+ np.savetxt(
564
+ weights_filename,
565
+ weights.reshape(-1, 1).astype(np.float32),
566
+ delimiter=",",
567
+ )
568
  def_datasets += f"""
569
  X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))"""
570
 
571
  if multioutput:
572
+ def_datasets += f"""
573
  y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')))"""
574
  else:
575
+ def_datasets += f"""
576
  y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
577
 
578
  if weights is not None:
 
584
  weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
585
  return def_datasets
586
 
587
+
588
+ def _make_hyperparams_julia_str(
589
+ X,
590
+ alpha,
591
+ annealing,
592
+ batchSize,
593
+ batching,
594
+ binary_operators,
595
+ constraints_str,
596
+ def_hyperparams,
597
+ equation_file,
598
+ fast_cycle,
599
+ fractionReplacedHof,
600
+ hofMigration,
601
+ maxdepth,
602
+ maxsize,
603
+ migration,
604
+ optimizer_algorithm,
605
+ optimizer_nrestarts,
606
+ optimize_probability,
607
+ optimizer_iterations,
608
+ npop,
609
+ parsimony,
610
+ perturbationFactor,
611
+ populations,
612
+ procs,
613
+ shouldOptimizeConstants,
614
+ unary_operators,
615
+ useFrequency,
616
+ use_custom_variable_names,
617
+ variable_names,
618
+ warmupMaxsizeBy,
619
+ weightAddNode,
620
+ ncyclesperiteration,
621
+ fractionReplaced,
622
+ topn,
623
+ verbosity,
624
+ progress,
625
+ loss,
626
+ weightDeleteNode,
627
+ weightDoNothing,
628
+ weightInsertNode,
629
+ weightMutateConstant,
630
+ weightMutateOperator,
631
+ weightRandomize,
632
+ weightSimplify,
633
+ weights,
634
+ tournament_selection_n,
635
+ tournament_selection_p,
636
+ **kwargs,
637
+ ):
638
  try:
639
  term_width = shutil.get_terminal_size().columns
640
  except:
641
+ _, term_width = subprocess.check_output(["stty", "size"]).split()
642
+
643
  def tuple_fix(ops):
644
  if len(ops) > 1:
645
+ return ", ".join(ops)
646
  elif len(ops) == 0:
647
+ return ""
648
  else:
649
+ return ops[0] + ","
650
 
651
  def_hyperparams += f"""\n
652
  plus=(+)
 
716
  terminal_width={term_width:d}
717
  """
718
 
719
+ def_hyperparams += "\n)"
720
  return def_hyperparams
721
 
722
 
 
749
  for op in binary_operators:
750
  if op not in constraints:
751
  constraints[op] = (-1, -1)
752
+ if op in ["plus", "sub"]:
753
  if constraints[op][0] != constraints[op][1]:
754
  raise NotImplementedError(
755
+ "You need equal constraints on both sides for - and *, due to simplification strategies."
756
+ )
757
+ elif op == "mult":
758
  # Make sure the complex expression is in the left side.
759
  if constraints[op][0] == -1:
760
  continue
761
  elif constraints[op][1] == -1 or constraints[op][0] < constraints[op][1]:
762
+ constraints[op][0], constraints[op][1] = (
763
+ constraints[op][1],
764
+ constraints[op][0],
765
+ )
766
 
767
 
768
  def _create_inline_operators(binary_operators, unary_operators, **kwargs):
 
770
  for op_list in [binary_operators, unary_operators]:
771
  for i in range(len(op_list)):
772
  op = op_list[i]
773
+ is_user_defined_operator = "(" in op
774
 
775
  if is_user_defined_operator:
776
  def_hyperparams += op + "\n"
777
  # Cut off from the first non-alphanumeric char:
778
  first_non_char = [
779
+ j
780
+ for j in range(len(op))
781
+ if not (op[j].isalpha() or op[j].isdigit())
782
+ ][0]
783
  function_name = op[:first_non_char]
784
  op_list[i] = function_name
785
  return def_hyperparams
786
 
787
 
788
+ def _handle_feature_selection(
789
+ X, select_k_features, use_custom_variable_names, variable_names, y
790
+ ):
791
  if select_k_features is not None:
792
  selection = run_feature_selection(X, y, select_k_features)
793
  print(f"Using features {selection}")
794
  X = X[:, selection]
795
 
796
  if use_custom_variable_names:
797
+ variable_names = [
798
+ variable_names[selection[i]] for i in range(len(selection))
799
+ ]
800
  else:
801
  selection = None
802
  return X, variable_names, selection
 
807
  pkg_directory = Path(__file__).parents[1]
808
  default_project_file = pkg_directory / "Project.toml"
809
  tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
810
+ hyperparam_filename = tmpdir / f"hyperparams.jl"
811
+ dataset_filename = tmpdir / f"dataset.jl"
812
+ runfile_filename = tmpdir / f"runfile.jl"
813
  X_filename = tmpdir / "X.csv"
814
  y_filename = tmpdir / "y.csv"
815
  weights_filename = tmpdir / "weights.csv"
816
+ return dict(
817
+ pkg_directory=pkg_directory,
818
+ default_project_file=default_project_file,
819
+ X_filename=X_filename,
820
+ dataset_filename=dataset_filename,
821
+ hyperparam_filename=hyperparam_filename,
822
+ runfile_filename=runfile_filename,
823
+ tmpdir=tmpdir,
824
+ weights_filename=weights_filename,
825
+ y_filename=y_filename,
826
+ )
827
+
828
+
829
+ def _check_assertions(
830
+ X,
831
+ binary_operators,
832
+ unary_operators,
833
+ use_custom_variable_names,
834
+ variable_names,
835
+ weights,
836
+ y,
837
+ ):
838
  # Check for potential errors before they happen
839
  assert len(unary_operators) + len(binary_operators) > 0
840
  assert len(X.shape) == 2
 
846
  if use_custom_variable_names:
847
  assert len(variable_names) == X.shape[1]
848
 
849
+
850
  def _check_for_julia_installation():
851
  try:
852
  process = subprocess.Popen(["julia", "-v"], stdout=subprocess.PIPE, bufsize=-1)
853
  while True:
854
  line = process.stdout.readline()
855
+ if not line:
856
+ break
857
  process.stdout.close()
858
  process.wait()
859
  except FileNotFoundError:
860
  import os
861
+
862
+ raise RuntimeError(
863
+ f"Your current $PATH is: {os.environ['PATH']}\nPySR could not start julia. Make sure julia is installed and on your $PATH."
864
+ )
865
  process.kill()
866
 
867
 
868
  def run_feature_selection(X, y, select_k_features):
869
  """Use a gradient boosting tree regressor as a proxy for finding
870
+ the k most important features in X, returning indices for those
871
+ features as output."""
872
 
873
  from sklearn.ensemble import RandomForestRegressor
874
  from sklearn.feature_selection import SelectFromModel, SelectKBest
875
 
876
  clf = RandomForestRegressor(n_estimators=100, max_depth=3, random_state=0)
877
  clf.fit(X, y)
878
+ selector = SelectFromModel(
879
+ clf, threshold=-np.inf, max_features=select_k_features, prefit=True
880
+ )
881
  return selector.get_support(indices=True)
882
 
883
+
884
+ def get_hof(
885
+ equation_file=None,
886
+ n_features=None,
887
+ variable_names=None,
888
+ output_jax_format=None,
889
+ output_torch_format=None,
890
+ selection=None,
891
+ extra_sympy_mappings=None,
892
+ extra_jax_mappings=None,
893
+ extra_torch_mappings=None,
894
+ multioutput=None,
895
+ nout=None,
896
+ **kwargs,
897
+ ):
898
  """Get the equations from a hall of fame file. If no arguments
899
  entered, the ones used previously from a call to PySR will be used."""
900
 
901
  global global_state
902
 
903
+ if equation_file is None:
904
+ equation_file = global_state["equation_file"]
905
+ if n_features is None:
906
+ n_features = global_state["n_features"]
907
+ if variable_names is None:
908
+ variable_names = global_state["variable_names"]
909
+ if extra_sympy_mappings is None:
910
+ extra_sympy_mappings = global_state["extra_sympy_mappings"]
911
+ if extra_jax_mappings is None:
912
+ extra_jax_mappings = global_state["extra_jax_mappings"]
913
+ if extra_torch_mappings is None:
914
+ extra_torch_mappings = global_state["extra_torch_mappings"]
915
+ if output_torch_format is None:
916
+ output_torch_format = global_state["output_torch_format"]
917
+ if output_jax_format is None:
918
+ output_jax_format = global_state["output_jax_format"]
919
+ if multioutput is None:
920
+ multioutput = global_state["multioutput"]
921
+ if nout is None:
922
+ nout = global_state["nout"]
923
+ if selection is None:
924
+ selection = global_state["selection"]
925
+
926
+ global_state["selection"] = selection
927
+ global_state["equation_file"] = equation_file
928
+ global_state["n_features"] = n_features
929
+ global_state["variable_names"] = variable_names
930
+ global_state["extra_sympy_mappings"] = extra_sympy_mappings
931
+ global_state["extra_jax_mappings"] = extra_jax_mappings
932
+ global_state["extra_torch_mappings"] = extra_torch_mappings
933
+ global_state["output_torch_format"] = output_torch_format
934
+ global_state["output_jax_format"] = output_jax_format
935
+ global_state["multioutput"] = multioutput
936
+ global_state["nout"] = nout
937
+ global_state["selection"] = selection
938
 
939
  try:
940
  if multioutput:
941
+ all_outputs = [
942
+ pd.read_csv(str(equation_file) + f".out{i}" + ".bkup", sep="|")
943
+ for i in range(1, nout + 1)
944
+ ]
945
  else:
946
+ all_outputs = [pd.read_csv(str(equation_file) + ".bkup", sep="|")]
947
  except FileNotFoundError:
948
+ raise RuntimeError(
949
+ "Couldn't find equation file! The equation search likely exited before a single iteration completed."
950
+ )
951
 
952
  ret_outputs = []
953
 
 
962
  jax_format = []
963
  if output_torch_format:
964
  torch_format = []
965
+ use_custom_variable_names = len(variable_names) != 0
966
+ local_sympy_mappings = {**extra_sympy_mappings, **sympy_mappings}
 
 
 
967
 
968
  if use_custom_variable_names:
969
  sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
970
  else:
971
+ sympy_symbols = [sympy.Symbol("x%d" % i) for i in range(n_features)]
972
 
973
  for i in range(len(output)):
974
+ eqn = sympify(output.loc[i, "Equation"], locals=local_sympy_mappings)
975
  sympy_format.append(eqn)
976
 
977
  # Numpy:
 
980
  # JAX:
981
  if output_jax_format:
982
  from .export_jax import sympy2jax
983
+
984
  func, params = sympy2jax(eqn, sympy_symbols, selection)
985
+ jax_format.append({"callable": func, "parameters": params})
986
 
987
  # Torch:
988
  if output_torch_format:
989
  from .export_torch import sympy2torch
990
+
991
  module = sympy2torch(eqn, sympy_symbols, selection=selection)
992
  torch_format.append(module)
993
 
994
+ curMSE = output.loc[i, "MSE"]
995
+ curComplexity = output.loc[i, "Complexity"]
996
 
997
  if lastMSE is None:
998
  cur_score = 0.0
999
  else:
1000
+ cur_score = -np.log(curMSE / lastMSE) / (curComplexity - lastComplexity)
1001
 
1002
  scores.append(cur_score)
1003
  lastMSE = curMSE
1004
  lastComplexity = curComplexity
1005
 
1006
+ output["score"] = np.array(scores)
1007
+ output["sympy_format"] = sympy_format
1008
+ output["lambda_format"] = lambda_format
1009
+ output_cols = [
1010
+ "Complexity",
1011
+ "MSE",
1012
+ "score",
1013
+ "Equation",
1014
+ "sympy_format",
1015
+ "lambda_format",
1016
+ ]
1017
  if output_jax_format:
1018
+ output_cols += ["jax_format"]
1019
+ output["jax_format"] = jax_format
1020
  if output_torch_format:
1021
+ output_cols += ["torch_format"]
1022
+ output["torch_format"] = torch_format
1023
 
1024
  ret_outputs.append(output[output_cols])
1025
 
 
1028
  else:
1029
  return ret_outputs[0]
1030
 
1031
+
1032
  def best_row(equations=None):
1033
  """Return the best row of a hall of fame file using the score column.
1034
  By default this uses the last equation file.
1035
  """
1036
+ if equations is None:
1037
+ equations = get_hof()
1038
  if isinstance(equations, list):
1039
+ return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
1040
  else:
1041
+ return equations.iloc[np.argmax(equations["score"])]
1042
+
1043
 
1044
  def best_tex(equations=None):
1045
  """Return the equation with the best score, in latex format
1046
  By default this uses the last equation file.
1047
  """
1048
+ if equations is None:
1049
+ equations = get_hof()
1050
  if isinstance(equations, list):
1051
+ return [
1052
+ sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
1053
+ ]
1054
  else:
1055
+ return sympy.latex(best_row(equations)["sympy_format"].simplify())
1056
+
1057
 
1058
  def best(equations=None):
1059
  """Return the equation with the best score, in sympy format.
1060
  By default this uses the last equation file.
1061
  """
1062
+ if equations is None:
1063
+ equations = get_hof()
1064
  if isinstance(equations, list):
1065
+ return [best_row(eq)["sympy_format"].simplify() for eq in equations]
1066
  else:
1067
+ return best_row(equations)["sympy_format"].simplify()
1068
+
1069
 
1070
  def best_callable(equations=None):
1071
  """Return the equation with the best score, in callable format.
1072
  By default this uses the last equation file.
1073
  """
1074
+ if equations is None:
1075
+ equations = get_hof()
1076
  if isinstance(equations, list):
1077
+ return [best_row(eq)["lambda_format"] for eq in equations]
1078
  else:
1079
+ return best_row(equations)["lambda_format"]
1080
+
1081
 
1082
  def _escape_filename(filename):
1083
  """Turns a file into a string representation with correctly escaped backslashes"""
1084
  repr = str(filename)
1085
+ repr = repr.replace("\\", "\\\\")
1086
  return repr
1087
 
1088
+
1089
  # https://gist.github.com/garrettdreyfus/8153571
1090
  def _yesno(question):
1091
  """Simple Yes/No Function."""
1092
+ prompt = f"{question} (y/n): "
1093
  ans = input(prompt).strip().lower()
1094
+ if ans not in ["y", "n"]:
1095
+ print(f"{ans} is invalid, please try again...")
1096
  return _yesno(question)
1097
+ if ans == "y":
1098
  return True
1099
  return False
1100
 
1101
 
1102
  class CallableEquation(object):
1103
  """Simple wrapper for numpy lambda functions built with sympy"""
1104
+
1105
  def __init__(self, sympy_symbols, eqn, selection=None):
1106
  self._sympy = eqn
1107
  self._sympy_symbols = sympy_symbols
 
1116
  return self._lambda(*X[:, self._selection].T)
1117
  else:
1118
  return self._lambda(*X.T)
 
setup.py CHANGED
@@ -12,19 +12,13 @@ setuptools.setup(
12
  long_description=long_description,
13
  long_description_content_type="text/markdown",
14
  url="https://github.com/MilesCranmer/pysr",
15
- install_requires=[
16
- "numpy",
17
- "pandas",
18
- "sympy"
19
- ],
20
  packages=setuptools.find_packages(),
21
- package_data={
22
- 'pysr': ['../Project.toml', '../datasets/*']
23
- },
24
  include_package_data=False,
25
  classifiers=[
26
  "Programming Language :: Python :: 3",
27
  "Operating System :: OS Independent",
28
  ],
29
- python_requires='>=3.7',
30
  )
 
12
  long_description=long_description,
13
  long_description_content_type="text/markdown",
14
  url="https://github.com/MilesCranmer/pysr",
15
+ install_requires=["numpy", "pandas", "sympy"],
 
 
 
 
16
  packages=setuptools.find_packages(),
17
+ package_data={"pysr": ["../Project.toml", "../datasets/*"]},
 
 
18
  include_package_data=False,
19
  classifiers=[
20
  "Programming Language :: Python :: 3",
21
  "Operating System :: OS Independent",
22
  ],
23
+ python_requires=">=3.7",
24
  )
test/test.py CHANGED
@@ -6,6 +6,7 @@ import sympy
6
  from sympy import lambdify
7
  import pandas as pd
8
 
 
9
  class TestPipeline(unittest.TestCase):
10
  def setUp(self):
11
  self.default_test_kwargs = dict(
@@ -17,86 +18,105 @@ class TestPipeline(unittest.TestCase):
17
  )
18
  np.random.seed(0)
19
  self.X = np.random.randn(100, 5)
20
-
21
  def test_linear_relation(self):
22
  y = self.X[:, 0]
23
  equations = pysr(self.X, y, **self.default_test_kwargs)
24
  print(equations)
25
- self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
26
 
27
  def test_multioutput_custom_operator(self):
28
- y = self.X[:, [0, 1]]**2
29
- equations = pysr(self.X, y,
30
- unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
31
- extra_sympy_mappings={'sq': lambda x: x**2},
32
- **self.default_test_kwargs,
33
- procs=0)
 
 
 
 
34
  print(equations)
35
- self.assertLessEqual(equations[0].iloc[-1]['MSE'], 1e-4)
36
- self.assertLessEqual(equations[1].iloc[-1]['MSE'], 1e-4)
37
 
38
  def test_multioutput_weighted_with_callable(self):
39
- y = self.X[:, [0, 1]]**2
40
  w = np.random.rand(*y.shape)
41
  w[w < 0.5] = 0.0
42
  w[w >= 0.5] = 1.0
43
 
44
  # Double equation when weights are 0:
45
- y += (1-w) * y
46
  # Thus, pysr needs to use the weights to find the right equation!
47
 
48
- equations = pysr(self.X, y, weights=w,
49
- unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
50
- extra_sympy_mappings={'sq': lambda x: x**2},
51
- **self.default_test_kwargs,
52
- procs=0)
 
 
 
 
 
53
 
54
  np.testing.assert_almost_equal(
55
- best_callable()[0](self.X),
56
- self.X[:, 0]**2,
57
- decimal=4)
58
  np.testing.assert_almost_equal(
59
- best_callable()[1](self.X),
60
- self.X[:, 1]**2,
61
- decimal=4)
62
 
63
  def test_empty_operators_single_input(self):
64
  X = np.random.randn(100, 1)
65
  y = X[:, 0] + 3.0
66
- equations = pysr(X, y,
67
- unary_operators=[], binary_operators=["plus"],
68
- **self.default_test_kwargs)
 
 
 
 
 
 
69
 
70
- self.assertLessEqual(equations.iloc[-1]['MSE'], 1e-4)
71
 
72
  class TestBest(unittest.TestCase):
73
  def setUp(self):
74
- equations = pd.DataFrame({
75
- 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
76
- 'MSE': [1.0, 0.1, 1e-5],
77
- 'Complexity': [1, 2, 3]
78
- })
 
 
79
 
80
- equations['Complexity MSE Equation'.split(' ')].to_csv(
81
- 'equation_file.csv.bkup', sep='|')
 
82
 
83
  self.equations = get_hof(
84
- 'equation_file.csv', n_features=2,
85
- variables_names='x0 x1'.split(' '),
86
- extra_sympy_mappings={}, output_jax_format=False,
87
- multioutput=False, nout=1)
 
 
 
 
88
 
89
  def test_best(self):
90
- self.assertEqual(best(self.equations), sympy.cos(sympy.Symbol('x0'))**2)
91
- self.assertEqual(best(), sympy.cos(sympy.Symbol('x0'))**2)
92
 
93
  def test_best_tex(self):
94
- self.assertEqual(best_tex(self.equations), '\\cos^{2}{\\left(x_{0} \\right)}')
95
- self.assertEqual(best_tex(), '\\cos^{2}{\\left(x_{0} \\right)}')
96
 
97
  def test_best_lambda(self):
98
  X = np.random.randn(10, 2)
99
- y = np.cos(X[:, 0])**2
100
  for f in [best_callable(), best_callable(self.equations)]:
101
  np.testing.assert_almost_equal(f(X), y, decimal=4)
102
 
@@ -107,22 +127,23 @@ class TestFeatureSelection(unittest.TestCase):
107
 
108
  def test_feature_selection(self):
109
  X = np.random.randn(20000, 5)
110
- y = X[:, 2]**2 + X[:, 3]**2
111
  selected = run_feature_selection(X, y, select_k_features=2)
112
  self.assertEqual(sorted(selected), [2, 3])
113
 
114
  def test_feature_selection_handler(self):
115
  X = np.random.randn(20000, 5)
116
- y = X[:, 2]**2 + X[:, 3]**2
117
- var_names = [f'x{i}' for i in range(5)]
118
  selected_X, selected_var_names, selection = _handle_feature_selection(
119
- X, select_k_features=2,
120
- use_custom_variable_names=True,
121
- variable_names=[f'x{i}' for i in range(5)],
122
- y=y)
 
 
123
  self.assertTrue((2 in selection) and (3 in selection))
124
- self.assertEqual(set(selected_var_names), set('x2 x3'.split(' ')))
125
  np.testing.assert_array_equal(
126
- np.sort(selected_X, axis=1),
127
- np.sort(X[:, [2, 3]], axis=1)
128
- )
 
6
  from sympy import lambdify
7
  import pandas as pd
8
 
9
+
10
  class TestPipeline(unittest.TestCase):
11
  def setUp(self):
12
  self.default_test_kwargs = dict(
 
18
  )
19
  np.random.seed(0)
20
  self.X = np.random.randn(100, 5)
21
+
22
  def test_linear_relation(self):
23
  y = self.X[:, 0]
24
  equations = pysr(self.X, y, **self.default_test_kwargs)
25
  print(equations)
26
+ self.assertLessEqual(equations.iloc[-1]["MSE"], 1e-4)
27
 
28
  def test_multioutput_custom_operator(self):
29
+ y = self.X[:, [0, 1]] ** 2
30
+ equations = pysr(
31
+ self.X,
32
+ y,
33
+ unary_operators=["sq(x) = x^2"],
34
+ binary_operators=["plus"],
35
+ extra_sympy_mappings={"sq": lambda x: x ** 2},
36
+ **self.default_test_kwargs,
37
+ procs=0,
38
+ )
39
  print(equations)
40
+ self.assertLessEqual(equations[0].iloc[-1]["MSE"], 1e-4)
41
+ self.assertLessEqual(equations[1].iloc[-1]["MSE"], 1e-4)
42
 
43
  def test_multioutput_weighted_with_callable(self):
44
+ y = self.X[:, [0, 1]] ** 2
45
  w = np.random.rand(*y.shape)
46
  w[w < 0.5] = 0.0
47
  w[w >= 0.5] = 1.0
48
 
49
  # Double equation when weights are 0:
50
+ y += (1 - w) * y
51
  # Thus, pysr needs to use the weights to find the right equation!
52
 
53
+ equations = pysr(
54
+ self.X,
55
+ y,
56
+ weights=w,
57
+ unary_operators=["sq(x) = x^2"],
58
+ binary_operators=["plus"],
59
+ extra_sympy_mappings={"sq": lambda x: x ** 2},
60
+ **self.default_test_kwargs,
61
+ procs=0,
62
+ )
63
 
64
  np.testing.assert_almost_equal(
65
+ best_callable()[0](self.X), self.X[:, 0] ** 2, decimal=4
66
+ )
 
67
  np.testing.assert_almost_equal(
68
+ best_callable()[1](self.X), self.X[:, 1] ** 2, decimal=4
69
+ )
 
70
 
71
  def test_empty_operators_single_input(self):
72
  X = np.random.randn(100, 1)
73
  y = X[:, 0] + 3.0
74
+ equations = pysr(
75
+ X,
76
+ y,
77
+ unary_operators=[],
78
+ binary_operators=["plus"],
79
+ **self.default_test_kwargs,
80
+ )
81
+
82
+ self.assertLessEqual(equations.iloc[-1]["MSE"], 1e-4)
83
 
 
84
 
85
  class TestBest(unittest.TestCase):
86
  def setUp(self):
87
+ equations = pd.DataFrame(
88
+ {
89
+ "Equation": ["1.0", "cos(x0)", "square(cos(x0))"],
90
+ "MSE": [1.0, 0.1, 1e-5],
91
+ "Complexity": [1, 2, 3],
92
+ }
93
+ )
94
 
95
+ equations["Complexity MSE Equation".split(" ")].to_csv(
96
+ "equation_file.csv.bkup", sep="|"
97
+ )
98
 
99
  self.equations = get_hof(
100
+ "equation_file.csv",
101
+ n_features=2,
102
+ variables_names="x0 x1".split(" "),
103
+ extra_sympy_mappings={},
104
+ output_jax_format=False,
105
+ multioutput=False,
106
+ nout=1,
107
+ )
108
 
109
  def test_best(self):
110
+ self.assertEqual(best(self.equations), sympy.cos(sympy.Symbol("x0")) ** 2)
111
+ self.assertEqual(best(), sympy.cos(sympy.Symbol("x0")) ** 2)
112
 
113
  def test_best_tex(self):
114
+ self.assertEqual(best_tex(self.equations), "\\cos^{2}{\\left(x_{0} \\right)}")
115
+ self.assertEqual(best_tex(), "\\cos^{2}{\\left(x_{0} \\right)}")
116
 
117
  def test_best_lambda(self):
118
  X = np.random.randn(10, 2)
119
+ y = np.cos(X[:, 0]) ** 2
120
  for f in [best_callable(), best_callable(self.equations)]:
121
  np.testing.assert_almost_equal(f(X), y, decimal=4)
122
 
 
127
 
128
  def test_feature_selection(self):
129
  X = np.random.randn(20000, 5)
130
+ y = X[:, 2] ** 2 + X[:, 3] ** 2
131
  selected = run_feature_selection(X, y, select_k_features=2)
132
  self.assertEqual(sorted(selected), [2, 3])
133
 
134
  def test_feature_selection_handler(self):
135
  X = np.random.randn(20000, 5)
136
+ y = X[:, 2] ** 2 + X[:, 3] ** 2
137
+ var_names = [f"x{i}" for i in range(5)]
138
  selected_X, selected_var_names, selection = _handle_feature_selection(
139
+ X,
140
+ select_k_features=2,
141
+ use_custom_variable_names=True,
142
+ variable_names=[f"x{i}" for i in range(5)],
143
+ y=y,
144
+ )
145
  self.assertTrue((2 in selection) and (3 in selection))
146
+ self.assertEqual(set(selected_var_names), set("x2 x3".split(" ")))
147
  np.testing.assert_array_equal(
148
+ np.sort(selected_X, axis=1), np.sort(X[:, [2, 3]], axis=1)
149
+ )
 
test/test_jax.py CHANGED
@@ -7,37 +7,48 @@ from jax import random
7
  from jax import grad
8
  import sympy
9
 
 
10
  class TestJAX(unittest.TestCase):
11
  def setUp(self):
12
  np.random.seed(0)
13
 
14
  def test_sympy2jax(self):
15
- x, y, z = sympy.symbols('x y z')
16
  cosx = 1.0 * sympy.cos(x) + y
17
  key = random.PRNGKey(0)
18
  X = random.normal(key, (1000, 2))
19
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
20
  f, params = sympy2jax(cosx, [x, y, z])
21
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
 
22
  def test_pipeline(self):
23
  X = np.random.randn(100, 10)
24
- equations = pd.DataFrame({
25
- 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
26
- 'MSE': [1.0, 0.1, 1e-5],
27
- 'Complexity': [1, 2, 3]
28
- })
 
 
29
 
30
- equations['Complexity MSE Equation'.split(' ')].to_csv(
31
- 'equation_file.csv.bkup', sep='|')
 
32
 
33
  equations = get_hof(
34
- 'equation_file.csv', n_features=2, variables_names='x1 x2 x3'.split(' '),
35
- extra_sympy_mappings={}, output_jax_format=True,
36
- multioutput=False, nout=1, selection=[1, 2, 3])
 
 
 
 
 
 
37
 
38
  jformat = equations.iloc[-1].jax_format
39
  np.testing.assert_almost_equal(
40
- np.array(jformat['callable'](jnp.array(X), jformat['parameters'])),
41
- np.square(np.cos(X[:, 1])), # Select feature 1
42
- decimal=4
43
  )
 
7
  from jax import grad
8
  import sympy
9
 
10
+
11
  class TestJAX(unittest.TestCase):
12
  def setUp(self):
13
  np.random.seed(0)
14
 
15
  def test_sympy2jax(self):
16
+ x, y, z = sympy.symbols("x y z")
17
  cosx = 1.0 * sympy.cos(x) + y
18
  key = random.PRNGKey(0)
19
  X = random.normal(key, (1000, 2))
20
  true = 1.0 * jnp.cos(X[:, 0]) + X[:, 1]
21
  f, params = sympy2jax(cosx, [x, y, z])
22
  self.assertTrue(jnp.all(jnp.isclose(f(X, params), true)).item())
23
+
24
  def test_pipeline(self):
25
  X = np.random.randn(100, 10)
26
+ equations = pd.DataFrame(
27
+ {
28
+ "Equation": ["1.0", "cos(x0)", "square(cos(x0))"],
29
+ "MSE": [1.0, 0.1, 1e-5],
30
+ "Complexity": [1, 2, 3],
31
+ }
32
+ )
33
 
34
+ equations["Complexity MSE Equation".split(" ")].to_csv(
35
+ "equation_file.csv.bkup", sep="|"
36
+ )
37
 
38
  equations = get_hof(
39
+ "equation_file.csv",
40
+ n_features=2,
41
+ variables_names="x1 x2 x3".split(" "),
42
+ extra_sympy_mappings={},
43
+ output_jax_format=True,
44
+ multioutput=False,
45
+ nout=1,
46
+ selection=[1, 2, 3],
47
+ )
48
 
49
  jformat = equations.iloc[-1].jax_format
50
  np.testing.assert_almost_equal(
51
+ np.array(jformat["callable"](jnp.array(X), jformat["parameters"])),
52
+ np.square(np.cos(X[:, 1])), # Select feature 1
53
+ decimal=4,
54
  )
test/test_torch.py CHANGED
@@ -5,38 +5,49 @@ from pysr import sympy2torch, get_hof
5
  import torch
6
  import sympy
7
 
 
8
  class TestTorch(unittest.TestCase):
9
  def setUp(self):
10
  np.random.seed(0)
11
 
12
  def test_sympy2torch(self):
13
- x, y, z = sympy.symbols('x y z')
14
  cosx = 1.0 * sympy.cos(x) + y
15
  X = torch.tensor(np.random.randn(1000, 3))
16
  true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
17
  torch_module = sympy2torch(cosx, [x, y, z])
18
  self.assertTrue(
19
- np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
20
  )
 
21
  def test_pipeline(self):
22
  X = np.random.randn(100, 10)
23
- equations = pd.DataFrame({
24
- 'Equation': ['1.0', 'cos(x0)', 'square(cos(x0))'],
25
- 'MSE': [1.0, 0.1, 1e-5],
26
- 'Complexity': [1, 2, 3]
27
- })
 
 
28
 
29
- equations['Complexity MSE Equation'.split(' ')].to_csv(
30
- 'equation_file.csv.bkup', sep='|')
 
31
 
32
  equations = get_hof(
33
- 'equation_file.csv', n_features=2, variables_names='x1 x2 x3'.split(' '),
34
- extra_sympy_mappings={}, output_torch_format=True,
35
- multioutput=False, nout=1, selection=[1, 2, 3])
 
 
 
 
 
 
36
 
37
  tformat = equations.iloc[-1].torch_format
38
  np.testing.assert_almost_equal(
39
- tformat(torch.tensor(X)).detach().numpy(),
40
- np.square(np.cos(X[:, 1])), #Selection 1st feature
41
- decimal=4
42
  )
 
5
  import torch
6
  import sympy
7
 
8
+
9
  class TestTorch(unittest.TestCase):
10
  def setUp(self):
11
  np.random.seed(0)
12
 
13
  def test_sympy2torch(self):
14
+ x, y, z = sympy.symbols("x y z")
15
  cosx = 1.0 * sympy.cos(x) + y
16
  X = torch.tensor(np.random.randn(1000, 3))
17
  true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
18
  torch_module = sympy2torch(cosx, [x, y, z])
19
  self.assertTrue(
20
+ np.all(np.isclose(torch_module(X).detach().numpy(), true.detach().numpy()))
21
  )
22
+
23
  def test_pipeline(self):
24
  X = np.random.randn(100, 10)
25
+ equations = pd.DataFrame(
26
+ {
27
+ "Equation": ["1.0", "cos(x0)", "square(cos(x0))"],
28
+ "MSE": [1.0, 0.1, 1e-5],
29
+ "Complexity": [1, 2, 3],
30
+ }
31
+ )
32
 
33
+ equations["Complexity MSE Equation".split(" ")].to_csv(
34
+ "equation_file.csv.bkup", sep="|"
35
+ )
36
 
37
  equations = get_hof(
38
+ "equation_file.csv",
39
+ n_features=2,
40
+ variables_names="x1 x2 x3".split(" "),
41
+ extra_sympy_mappings={},
42
+ output_torch_format=True,
43
+ multioutput=False,
44
+ nout=1,
45
+ selection=[1, 2, 3],
46
+ )
47
 
48
  tformat = equations.iloc[-1].torch_format
49
  np.testing.assert_almost_equal(
50
+ tformat(torch.tensor(X)).detach().numpy(),
51
+ np.square(np.cos(X[:, 1])), # Selection 1st feature
52
+ decimal=4,
53
  )