MilesCranmer committed on
Commit
5cee3b5
1 Parent(s): 76e7a47

Store arguments in dictionary to prevent argument errors

Browse files
Files changed (1) hide show
  1. pysr/sr.py +70 -50
pysr/sr.py CHANGED
@@ -192,8 +192,6 @@ def pysr(X=None, y=None, weights=None,
192
 
193
  """
194
  raise_depreciation_errors(limitPowComplexity, threads)
195
- auxiliary_filename, X_filename, dataset_filename, hyperparam_filename, julia_auxiliary_filenames, operator_filename \
196
- ,pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename = set_paths(tempdir)
197
 
198
  if isinstance(X, pd.DataFrame):
199
  variable_names = list(X.columns)
@@ -204,9 +202,13 @@ def pysr(X=None, y=None, weights=None,
204
  if len(X.shape) == 1:
205
  X = X[:, None]
206
 
207
- check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y)
 
208
 
209
- X, variable_names = handle_feature_selection(X, select_k_features, use_custom_variable_names, variable_names, y)
 
 
 
210
 
211
  if maxdepth is None:
212
  maxdepth = maxsize
@@ -220,52 +222,66 @@ def pysr(X=None, y=None, weights=None,
220
  if isinstance(unary_operators, str):
221
  unary_operators = [unary_operators]
222
  if X is None:
223
- X, y = using_test_input(X, test, y)
224
-
225
- # Add pre-defined functions to Julia
226
- def_hyperparams = _predefined_function_addition(binary_operators, unary_operators)
227
-
228
- #arbitrary complexity by default
229
- _handle_constraints(binary_operators, constraints, unary_operators)
230
-
231
- constraints_str = _make_constraints_str(binary_operators, constraints, unary_operators)
232
-
233
-
234
- def_hyperparams = _make_hyperparams_julia_str(X, alpha, annealing, batchSize, batching, binary_operators,
235
- constraints_str, def_hyperparams, equation_file, fast_cycle,
236
- fractionReplacedHof, hofMigration, limitPowComplexity, maxdepth,
237
- maxsize, migration, nrestarts, operator_filename, parsimony,
238
- perturbationFactor, populations, procs, shouldOptimizeConstants,
239
- unary_operators, useFrequency, use_custom_variable_names, variable_names,
240
- warmupMaxsize, weightAddNode, weightDeleteNode, weightDoNothing,
241
- weightInsertNode, weightMutateConstant, weightMutateOperator,
242
- weightRandomize, weightSimplify, weights)
243
-
244
- def_auxiliary = _make_auxiliary_julia_str(julia_auxiliary_filenames)
245
-
246
- def_datasets = _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename)
247
-
248
- _create_julia_files(auxiliary_filename, dataset_filename, def_auxiliary, def_datasets, def_hyperparams, fractionReplaced, hyperparam_filename,
249
- ncyclesperiteration, niterations, npop, pkg_filename, runfile_filename, topn, verbosity)
250
-
251
- _final_pysr_process(julia_optimization, procs, runfile_filename, timeout)
252
-
253
- _set_globals(X, equation_file, extra_sympy_mappings, variable_names)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
254
 
255
  if delete_tempfiles:
256
- shutil.rmtree(tmpdir)
257
 
258
  return get_hof()
259
 
260
 
261
- def _make_auxiliary_julia_str(julia_auxiliary_filenames):
262
  def_auxiliary = '\n'.join([
263
  f"""include("{_escape_filename(aux_fname)}")""" for aux_fname in julia_auxiliary_filenames
264
  ])
265
  return def_auxiliary
266
 
267
 
268
- def _set_globals(X, equation_file, extra_sympy_mappings, variable_names):
269
  global global_n_features
270
  global global_equation_file
271
  global global_variable_names
@@ -276,7 +292,7 @@ def _set_globals(X, equation_file, extra_sympy_mappings, variable_names):
276
  global_extra_sympy_mappings = extra_sympy_mappings
277
 
278
 
279
- def _final_pysr_process(julia_optimization, procs, runfile_filename, timeout):
280
  command = [
281
  f'julia', f'-O{julia_optimization:d}',
282
  f'-p', f'{procs}',
@@ -300,7 +316,7 @@ def _final_pysr_process(julia_optimization, procs, runfile_filename, timeout):
300
 
301
 
302
  def _create_julia_files(auxiliary_filename, dataset_filename, def_auxiliary, def_datasets, def_hyperparams, fractionReplaced, hyperparam_filename,
303
- ncyclesperiteration, niterations, npop, pkg_filename, runfile_filename, topn, verbosity):
304
  with open(hyperparam_filename, 'w') as f:
305
  print(def_hyperparams, file=f)
306
  with open(dataset_filename, 'w') as f:
@@ -318,7 +334,7 @@ def _create_julia_files(auxiliary_filename, dataset_filename, def_auxiliary, def
318
  print(f'rmprocs(nprocs)', file=f)
319
 
320
 
321
- def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename):
322
  def_datasets = """using DelimitedFiles"""
323
  np.savetxt(X_filename, X, delimiter=',')
324
  np.savetxt(y_filename, y, delimiter=',')
@@ -339,7 +355,7 @@ def _make_hyperparams_julia_str(X, alpha, annealing, batchSize, batching, binary
339
  parsimony, perturbationFactor, populations, procs, shouldOptimizeConstants,
340
  unary_operators, useFrequency, use_custom_variable_names, variable_names, warmupMaxsize, weightAddNode,
341
  weightDeleteNode, weightDoNothing, weightInsertNode, weightMutateConstant,
342
- weightMutateOperator, weightRandomize, weightSimplify, weights):
343
  def_hyperparams += f"""include("{_escape_filename(operator_filename)}")
344
  {constraints_str}
345
  const binops = {'[' + ', '.join(binary_operators) + ']'}
@@ -418,7 +434,7 @@ end"""
418
  return def_hyperparams
419
 
420
 
421
- def _make_constraints_str(binary_operators, constraints, unary_operators):
422
  constraints_str = "const una_constraints = ["
423
  first = True
424
  for op in unary_operators:
@@ -440,7 +456,7 @@ const bin_constraints = ["""
440
  return constraints_str
441
 
442
 
443
- def _handle_constraints(binary_operators, constraints, unary_operators):
444
  for op in unary_operators:
445
  if op not in constraints:
446
  constraints[op] = -1
@@ -459,7 +475,7 @@ def _handle_constraints(binary_operators, constraints, unary_operators):
459
  constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
460
 
461
 
462
- def _predefined_function_addition(binary_operators, unary_operators):
463
  def_hyperparams = ""
464
  for op_list in [binary_operators, unary_operators]:
465
  for i in range(len(op_list)):
@@ -477,7 +493,7 @@ def _predefined_function_addition(binary_operators, unary_operators):
477
  return def_hyperparams
478
 
479
 
480
- def using_test_input(X, test, y):
481
  if test == 'simple1':
482
  eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
483
  elif test == 'simple2':
@@ -505,7 +521,7 @@ def handle_feature_selection(X, select_k_features, use_custom_variable_names, va
505
  return X, variable_names
506
 
507
 
508
- def set_paths(tempdir):
509
  # System-independent paths
510
  pkg_directory = Path(__file__).parents[1] / 'julia'
511
  pkg_filename = pkg_directory / "sr.jl"
@@ -523,7 +539,6 @@ def set_paths(tempdir):
523
  for fname in julia_auxiliaries
524
  ]
525
 
526
-
527
  tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
528
  hyperparam_filename = tmpdir / f'hyperparams.jl'
529
  dataset_filename = tmpdir / f'dataset.jl'
@@ -532,8 +547,13 @@ def set_paths(tempdir):
532
  X_filename = tmpdir / "X.csv"
533
  y_filename = tmpdir / "y.csv"
534
  weights_filename = tmpdir / "weights.csv"
535
- return auxiliary_filename, X_filename, dataset_filename, hyperparam_filename, julia_auxiliary_filenames, \
536
- operator_filename, pkg_filename, runfile_filename, tmpdir, weights_filename, y_filename
 
 
 
 
 
537
 
538
 
539
  def check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y):
 
192
 
193
  """
194
  raise_depreciation_errors(limitPowComplexity, threads)
 
 
195
 
196
  if isinstance(X, pd.DataFrame):
197
  variable_names = list(X.columns)
 
202
  if len(X.shape) == 1:
203
  X = X[:, None]
204
 
205
+ check_assertions(X, binary_operators, unary_operators,
206
+ use_custom_variable_names, variable_names, weights, y)
207
 
208
+ X, variable_names = handle_feature_selection(
209
+ X, select_k_features,
210
+ use_custom_variable_names, variable_names, y
211
+ )
212
 
213
  if maxdepth is None:
214
  maxdepth = maxsize
 
222
  if isinstance(unary_operators, str):
223
  unary_operators = [unary_operators]
224
  if X is None:
225
+ X, y = _using_test_input(X, test, y)
226
+
227
+ kwargs = dict(X=X, y=y, weights=weights,
228
+ alpha=alpha, annealing=annealing, batchSize=batchSize,
229
+ batching=batching, binary_operators=binary_operators,
230
+ equation_file=equation_file, fast_cycle=fast_cycle,
231
+ fractionReplaced=fractionReplaced,
232
+ ncyclesperiteration=ncyclesperiteration,
233
+ niterations=niterations, npop=npop,
234
+ topn=topn, verbosity=verbosity,
235
+ julia_optimization=julia_optimization, timeout=timeout,
236
+ fractionReplacedHof=fractionReplacedHof,
237
+ hofMigration=hofMigration,
238
+ limitPowComplexity=limitPowComplexity, maxdepth=maxdepth,
239
+ maxsize=maxsize, migration=migration, nrestarts=nrestarts,
240
+ parsimony=parsimony, perturbationFactor=perturbationFactor,
241
+ populations=populations, procs=procs,
242
+ shouldOptimizeConstants=shouldOptimizeConstants,
243
+ unary_operators=unary_operators, useFrequency=useFrequency,
244
+ use_custom_variable_names=use_custom_variable_names,
245
+ variable_names=variable_names, warmupMaxsize=warmupMaxsize,
246
+ weightAddNode=weightAddNode,
247
+ weightDeleteNode=weightDeleteNode,
248
+ weightDoNothing=weightDoNothing,
249
+ weightInsertNode=weightInsertNode,
250
+ weightMutateConstant=weightMutateConstant,
251
+ weightMutateOperator=weightMutateOperator,
252
+ weightRandomize=weightRandomize,
253
+ weightSimplify=weightSimplify,
254
+ constraints=constraints)
255
+
256
+ kwargs = {**_set_paths(tempdir), **kwargs}
257
+
258
+ kwargs['def_hyperparams'] = _metaprogram_fast_operator(**kwargs)
259
+
260
+ _handle_constraints(**kwargs)
261
+
262
+ kwargs['constraints_str'] = _make_constraints_str(**kwargs)
263
+ kwargs['def_hyperparams'] = _make_hyperparams_julia_str(**kwargs)
264
+ kwargs['def_auxiliary'] = _make_auxiliary_julia_str(**kwargs)
265
+ kwargs['def_datasets'] = _make_datasets_julia_str(**kwargs)
266
+
267
+ _create_julia_files(**kwargs)
268
+ _final_pysr_process(**kwargs)
269
+ _set_globals(**kwargs)
270
 
271
  if delete_tempfiles:
272
+ shutil.rmtree(kwargs['tmpdir'])
273
 
274
  return get_hof()
275
 
276
 
277
+ def _make_auxiliary_julia_str(julia_auxiliary_filenames, **kwargs):
278
  def_auxiliary = '\n'.join([
279
  f"""include("{_escape_filename(aux_fname)}")""" for aux_fname in julia_auxiliary_filenames
280
  ])
281
  return def_auxiliary
282
 
283
 
284
+ def _set_globals(X, equation_file, extra_sympy_mappings, variable_names, **kwargs):
285
  global global_n_features
286
  global global_equation_file
287
  global global_variable_names
 
292
  global_extra_sympy_mappings = extra_sympy_mappings
293
 
294
 
295
+ def _final_pysr_process(julia_optimization, procs, runfile_filename, timeout, **kwargs):
296
  command = [
297
  f'julia', f'-O{julia_optimization:d}',
298
  f'-p', f'{procs}',
 
316
 
317
 
318
  def _create_julia_files(auxiliary_filename, dataset_filename, def_auxiliary, def_datasets, def_hyperparams, fractionReplaced, hyperparam_filename,
319
+ ncyclesperiteration, niterations, npop, pkg_filename, runfile_filename, topn, verbosity, **kwargs):
320
  with open(hyperparam_filename, 'w') as f:
321
  print(def_hyperparams, file=f)
322
  with open(dataset_filename, 'w') as f:
 
334
  print(f'rmprocs(nprocs)', file=f)
335
 
336
 
337
+ def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename, **kwargs):
338
  def_datasets = """using DelimitedFiles"""
339
  np.savetxt(X_filename, X, delimiter=',')
340
  np.savetxt(y_filename, y, delimiter=',')
 
355
  parsimony, perturbationFactor, populations, procs, shouldOptimizeConstants,
356
  unary_operators, useFrequency, use_custom_variable_names, variable_names, warmupMaxsize, weightAddNode,
357
  weightDeleteNode, weightDoNothing, weightInsertNode, weightMutateConstant,
358
+ weightMutateOperator, weightRandomize, weightSimplify, weights, **kwargs):
359
  def_hyperparams += f"""include("{_escape_filename(operator_filename)}")
360
  {constraints_str}
361
  const binops = {'[' + ', '.join(binary_operators) + ']'}
 
434
  return def_hyperparams
435
 
436
 
437
+ def _make_constraints_str(binary_operators, constraints, unary_operators, **kwargs):
438
  constraints_str = "const una_constraints = ["
439
  first = True
440
  for op in unary_operators:
 
456
  return constraints_str
457
 
458
 
459
+ def _handle_constraints(binary_operators, constraints, unary_operators, **kwargs):
460
  for op in unary_operators:
461
  if op not in constraints:
462
  constraints[op] = -1
 
475
  constraints[op][0], constraints[op][1] = constraints[op][1], constraints[op][0]
476
 
477
 
478
+ def _metaprogram_fast_operator(binary_operators, unary_operators, **kwargs):
479
  def_hyperparams = ""
480
  for op_list in [binary_operators, unary_operators]:
481
  for i in range(len(op_list)):
 
493
  return def_hyperparams
494
 
495
 
496
+ def _using_test_input(X, test, y):
497
  if test == 'simple1':
498
  eval_str = "np.sign(X[:, 2])*np.abs(X[:, 2])**2.5 + 5*np.cos(X[:, 3]) - 5"
499
  elif test == 'simple2':
 
521
  return X, variable_names
522
 
523
 
524
+ def _set_paths(tempdir):
525
  # System-independent paths
526
  pkg_directory = Path(__file__).parents[1] / 'julia'
527
  pkg_filename = pkg_directory / "sr.jl"
 
539
  for fname in julia_auxiliaries
540
  ]
541
 
 
542
  tmpdir = Path(tempfile.mkdtemp(dir=tempdir))
543
  hyperparam_filename = tmpdir / f'hyperparams.jl'
544
  dataset_filename = tmpdir / f'dataset.jl'
 
547
  X_filename = tmpdir / "X.csv"
548
  y_filename = tmpdir / "y.csv"
549
  weights_filename = tmpdir / "weights.csv"
550
+ return dict(auxiliary_filename=auxiliary_filename, X_filename=X_filename,
551
+ dataset_filename=dataset_filename,
552
+ hyperparam_filename=hyperparam_filename,
553
+ julia_auxiliary_filenames=julia_auxiliary_filenames,
554
+ operator_filename=operator_filename, pkg_filename=pkg_filename,
555
+ runfile_filename=runfile_filename, tmpdir=tmpdir,
556
+ weights_filename=weights_filename, y_filename=y_filename)
557
 
558
 
559
  def check_assertions(X, binary_operators, unary_operators, use_custom_variable_names, variable_names, weights, y):