MilesCranmer committed on
Commit
0087907
·
unverified ·
2 Parent(s): 569f4ba 14dfd82

Merge pull request #46 from MilesCranmer/multi-output

Browse files
Files changed (3) hide show
  1. Project.toml +1 -1
  2. pysr/sr.py +123 -61
  3. test/test.py +4 -3
Project.toml CHANGED
@@ -2,5 +2,5 @@
2
  SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
3
 
4
  [compat]
5
- SymbolicRegression = "0.5.16"
6
  julia = "1.5"
 
2
  SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
3
 
4
  [compat]
5
+ SymbolicRegression = "0.6.0"
6
  julia = "1.5"
pysr/sr.py CHANGED
@@ -19,6 +19,8 @@ global_equation_file = 'hall_of_fame.csv'
19
  global_n_features = None
20
  global_variable_names = []
21
  global_extra_sympy_mappings = {}
 
 
22
 
23
  sympy_mappings = {
24
  'div': lambda x, y : x/y,
@@ -276,6 +278,16 @@ def pysr(X=None, y=None, weights=None,
276
  if X is None:
277
  X, y = _using_test_input(X, test, y)
278
 
 
 
 
 
 
 
 
 
 
 
279
  kwargs = dict(X=X, y=y, weights=weights,
280
  alpha=alpha, annealing=annealing, batchSize=batchSize,
281
  batching=batching, binary_operators=binary_operators,
@@ -309,7 +321,8 @@ def pysr(X=None, y=None, weights=None,
309
  constraints=constraints,
310
  extra_sympy_mappings=extra_sympy_mappings,
311
  julia_project=julia_project, loss=loss,
312
- output_jax_format=output_jax_format)
 
313
 
314
  kwargs = {**_set_paths(tempdir), **kwargs}
315
 
@@ -358,15 +371,20 @@ def pysr(X=None, y=None, weights=None,
358
 
359
 
360
 
361
- def _set_globals(X, equation_file, extra_sympy_mappings, variable_names, **kwargs):
 
362
  global global_n_features
363
  global global_equation_file
364
  global global_variable_names
365
  global global_extra_sympy_mappings
 
 
366
  global_n_features = X.shape[1]
367
  global_equation_file = equation_file
368
  global_variable_names = variable_names
369
  global_extra_sympy_mappings = extra_sympy_mappings
 
 
370
 
371
 
372
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
@@ -393,9 +411,7 @@ def _cmd_runner(command, **kwargs):
393
  .replace('\\r', '\r')
394
  .encode(sys.stdout.encoding, errors='replace'))
395
 
396
- print(decoded_line, end='')
397
-
398
-
399
 
400
  process.stdout.close()
401
  process.wait()
@@ -438,17 +454,35 @@ def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, de
438
  print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
439
 
440
 
441
- def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename, **kwargs):
 
442
  def_datasets = """using DelimitedFiles"""
443
  np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
444
- np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
 
 
 
445
  if weights is not None:
446
- np.savetxt(weights_filename, weights.reshape(-1, 1), delimiter=',')
 
 
 
447
  def_datasets += f"""
448
- X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))
 
 
 
 
 
 
449
  y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
 
450
  if weights is not None:
451
- def_datasets += f"""
 
 
 
 
452
  weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
453
  return def_datasets
454
 
@@ -656,10 +690,10 @@ def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_
656
  # Check for potential errors before they happen
657
  assert len(unary_operators) + len(binary_operators) > 0
658
  assert len(X.shape) == 2
659
- assert len(y.shape) == 1
660
  assert X.shape[0] == y.shape[0]
661
  if weights is not None:
662
- assert len(weights.shape) == 1
663
  assert X.shape[0] == weights.shape[0]
664
  if use_custom_variable_names:
665
  assert len(variable_names) == X.shape[1]
@@ -693,7 +727,8 @@ def run_feature_selection(X, y, select_k_features):
693
  return selector.get_support(indices=True)
694
 
695
  def get_hof(equation_file=None, n_features=None, variable_names=None,
696
- extra_sympy_mappings=None, output_jax_format=False, **kwargs):
 
697
  """Get the equations from a hall of fame file. If no arguments
698
  entered, the ones used previously from a call to PySR will be used."""
699
 
@@ -701,99 +736,126 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
701
  global global_equation_file
702
  global global_variable_names
703
  global global_extra_sympy_mappings
 
 
704
 
705
  if equation_file is None: equation_file = global_equation_file
706
  if n_features is None: n_features = global_n_features
707
  if variable_names is None: variable_names = global_variable_names
708
  if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
 
 
709
 
710
  global_equation_file = equation_file
711
  global_n_features = n_features
712
  global_variable_names = variable_names
713
  global_extra_sympy_mappings = extra_sympy_mappings
 
 
714
 
715
  try:
716
- output = pd.read_csv(str(equation_file) + '.bkup', sep="|")
 
 
 
717
  except FileNotFoundError:
718
  raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
719
 
720
- scores = []
721
- lastMSE = None
722
- lastComplexity = 0
723
- sympy_format = []
724
- lambda_format = []
725
- if output_jax_format:
726
- jax_format = []
727
- use_custom_variable_names = (len(variable_names) != 0)
728
- local_sympy_mappings = {
729
- **extra_sympy_mappings,
730
- **sympy_mappings
731
- }
732
 
733
- if use_custom_variable_names:
734
- sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
735
- else:
736
- sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
737
 
738
- for i in range(len(output)):
739
- eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
740
- sympy_format.append(eqn)
 
 
741
  if output_jax_format:
742
- func, params = sympy2jax(eqn, sympy_symbols)
743
- jax_format.append({'callable': func, 'parameters': params})
744
- lambda_format.append(lambdify(sympy_symbols, eqn))
745
- curMSE = output.loc[i, 'MSE']
746
- curComplexity = output.loc[i, 'Complexity']
747
-
748
- if lastMSE is None:
749
- cur_score = 0.0
750
- else:
751
- cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
752
 
753
- scores.append(cur_score)
754
- lastMSE = curMSE
755
- lastComplexity = curComplexity
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
 
757
- output['score'] = np.array(scores)
758
- output['sympy_format'] = sympy_format
759
- output['lambda_format'] = lambda_format
760
- output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
761
- if output_jax_format:
762
- output_cols += ['jax_format']
763
- output['jax_format'] = jax_format
764
 
765
- return output[output_cols]
 
 
 
766
 
767
  def best_row(equations=None):
768
  """Return the best row of a hall of fame file using the score column.
769
  By default this uses the last equation file.
770
  """
771
  if equations is None: equations = get_hof()
772
- best_idx = np.argmax(equations['score'])
773
- return equations.iloc[best_idx]
 
 
774
 
775
  def best_tex(equations=None):
776
  """Return the equation with the best score, in latex format
777
  By default this uses the last equation file.
778
  """
779
  if equations is None: equations = get_hof()
780
- best_sympy = best_row(equations)['sympy_format']
781
- return sympy.latex(best_sympy.simplify())
 
 
782
 
783
  def best(equations=None):
784
  """Return the equation with the best score, in sympy format.
785
  By default this uses the last equation file.
786
  """
787
  if equations is None: equations = get_hof()
788
- best_sympy = best_row(equations)['sympy_format']
789
- return best_sympy.simplify()
 
 
790
 
791
  def best_callable(equations=None):
792
  """Return the equation with the best score, in callable format.
793
  By default this uses the last equation file.
794
  """
795
  if equations is None: equations = get_hof()
796
- return best_row(equations)['lambda_format']
 
 
 
797
 
798
  def _escape_filename(filename):
799
  """Turns a file into a string representation with correctly escaped backslashes"""
 
19
  global_n_features = None
20
  global_variable_names = []
21
  global_extra_sympy_mappings = {}
22
+ global_multioutput = False
23
+ global_nout = 1
24
 
25
  sympy_mappings = {
26
  'div': lambda x, y : x/y,
 
278
  if X is None:
279
  X, y = _using_test_input(X, test, y)
280
 
281
+ if len(y.shape) == 1 or (len(y.shape) == 2 and y.shape[1] == 1):
282
+ multioutput = False
283
+ nout = 1
284
+ y = y.reshape(-1)
285
+ elif len(y.shape) == 2:
286
+ multioutput = True
287
+ nout = y.shape[1]
288
+ else:
289
+ raise NotImplementedError("y shape not supported!")
290
+
291
  kwargs = dict(X=X, y=y, weights=weights,
292
  alpha=alpha, annealing=annealing, batchSize=batchSize,
293
  batching=batching, binary_operators=binary_operators,
 
321
  constraints=constraints,
322
  extra_sympy_mappings=extra_sympy_mappings,
323
  julia_project=julia_project, loss=loss,
324
+ output_jax_format=output_jax_format,
325
+ multioutput=multioutput, nout=nout)
326
 
327
  kwargs = {**_set_paths(tempdir), **kwargs}
328
 
 
371
 
372
 
373
 
374
+ def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
375
+ multioutput, nout, **kwargs):
376
  global global_n_features
377
  global global_equation_file
378
  global global_variable_names
379
  global global_extra_sympy_mappings
380
+ global global_multioutput
381
+ global global_nout
382
  global_n_features = X.shape[1]
383
  global_equation_file = equation_file
384
  global_variable_names = variable_names
385
  global_extra_sympy_mappings = extra_sympy_mappings
386
+ global_multioutput = multioutput
387
+ global_nout = nout
388
 
389
 
390
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
 
411
  .replace('\\r', '\r')
412
  .encode(sys.stdout.encoding, errors='replace'))
413
 
414
+ sys.stdout.buffer.write(decoded_line)
 
 
415
 
416
  process.stdout.close()
417
  process.wait()
 
454
  print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
455
 
456
 
457
+ def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
458
+ multioutput, **kwargs):
459
  def_datasets = """using DelimitedFiles"""
460
  np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
461
+ if multioutput:
462
+ np.savetxt(y_filename, y.astype(np.float32), delimiter=',')
463
+ else:
464
+ np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
465
  if weights is not None:
466
+ if multioutput:
467
+ np.savetxt(weights_filename, weights.astype(np.float32), delimiter=',')
468
+ else:
469
+ np.savetxt(weights_filename, weights.reshape(-1, 1).astype(np.float32), delimiter=',')
470
  def_datasets += f"""
471
+ X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))"""
472
+
473
+ if multioutput:
474
+ def_datasets+= f"""
475
+ y = copy(transpose(readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')))"""
476
+ else:
477
+ def_datasets+= f"""
478
  y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
479
+
480
  if weights is not None:
481
+ if multioutput:
482
+ def_datasets += f"""
483
+ weights = copy(transpose(readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')))"""
484
+ else:
485
+ def_datasets += f"""
486
  weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
487
  return def_datasets
488
 
 
690
  # Check for potential errors before they happen
691
  assert len(unary_operators) + len(binary_operators) > 0
692
  assert len(X.shape) == 2
693
+ assert len(y.shape) in [1, 2]
694
  assert X.shape[0] == y.shape[0]
695
  if weights is not None:
696
+ assert weights.shape == y.shape
697
  assert X.shape[0] == weights.shape[0]
698
  if use_custom_variable_names:
699
  assert len(variable_names) == X.shape[1]
 
727
  return selector.get_support(indices=True)
728
 
729
  def get_hof(equation_file=None, n_features=None, variable_names=None,
730
+ extra_sympy_mappings=None, output_jax_format=False,
731
+ multioutput=None, nout=None, **kwargs):
732
  """Get the equations from a hall of fame file. If no arguments
733
  entered, the ones used previously from a call to PySR will be used."""
734
 
 
736
  global global_equation_file
737
  global global_variable_names
738
  global global_extra_sympy_mappings
739
+ global global_multioutput
740
+ global global_nout
741
 
742
  if equation_file is None: equation_file = global_equation_file
743
  if n_features is None: n_features = global_n_features
744
  if variable_names is None: variable_names = global_variable_names
745
  if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
746
+ if multioutput is None: multioutput = global_multioutput
747
+ if nout is None: nout = global_nout
748
 
749
  global_equation_file = equation_file
750
  global_n_features = n_features
751
  global_variable_names = variable_names
752
  global_extra_sympy_mappings = extra_sympy_mappings
753
+ global_multioutput = multioutput
754
+ global_nout = nout
755
 
756
  try:
757
+ if multioutput:
758
+ all_outputs = [pd.read_csv(f'out{i}_' + str(equation_file) + '.bkup', sep="|") for i in range(1, nout+1)]
759
+ else:
760
+ all_outputs = [pd.read_csv(str(equation_file) + '.bkup', sep="|")]
761
  except FileNotFoundError:
762
  raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
763
 
764
+ ret_outputs = []
 
 
 
 
 
 
 
 
 
 
 
765
 
766
+ for output in all_outputs:
 
 
 
767
 
768
+ scores = []
769
+ lastMSE = None
770
+ lastComplexity = 0
771
+ sympy_format = []
772
+ lambda_format = []
773
  if output_jax_format:
774
+ jax_format = []
775
+ use_custom_variable_names = (len(variable_names) != 0)
776
+ local_sympy_mappings = {
777
+ **extra_sympy_mappings,
778
+ **sympy_mappings
779
+ }
 
 
 
 
780
 
781
+ if use_custom_variable_names:
782
+ sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
783
+ else:
784
+ sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
785
+
786
+ for i in range(len(output)):
787
+ eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
788
+ sympy_format.append(eqn)
789
+ if output_jax_format:
790
+ func, params = sympy2jax(eqn, sympy_symbols)
791
+ jax_format.append({'callable': func, 'parameters': params})
792
+ lambda_format.append(lambdify(sympy_symbols, eqn))
793
+ curMSE = output.loc[i, 'MSE']
794
+ curComplexity = output.loc[i, 'Complexity']
795
+
796
+ if lastMSE is None:
797
+ cur_score = 0.0
798
+ else:
799
+ cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
800
+
801
+ scores.append(cur_score)
802
+ lastMSE = curMSE
803
+ lastComplexity = curComplexity
804
+
805
+ output['score'] = np.array(scores)
806
+ output['sympy_format'] = sympy_format
807
+ output['lambda_format'] = lambda_format
808
+ output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
809
+ if output_jax_format:
810
+ output_cols += ['jax_format']
811
+ output['jax_format'] = jax_format
812
 
813
+ ret_outputs.append(output[output_cols])
 
 
 
 
 
 
814
 
815
+ if multioutput:
816
+ return ret_outputs
817
+ else:
818
+ return ret_outputs[0]
819
 
820
  def best_row(equations=None):
821
  """Return the best row of a hall of fame file using the score column.
822
  By default this uses the last equation file.
823
  """
824
  if equations is None: equations = get_hof()
825
+ if isinstance(equations, list):
826
+ return [eq.iloc[np.argmax(eq['score'])] for eq in equations]
827
+ else:
828
+ return equations.iloc[np.argmax(equations['score'])]
829
 
830
  def best_tex(equations=None):
831
  """Return the equation with the best score, in latex format
832
  By default this uses the last equation file.
833
  """
834
  if equations is None: equations = get_hof()
835
+ if isinstance(equations, list):
836
+ return [sympy.latex(best_row(eq)['sympy_format'].simplify()) for eq in equations]
837
+ else:
838
+ return sympy.latex(best_row(equations)['sympy_format'].simplify())
839
 
840
  def best(equations=None):
841
  """Return the equation with the best score, in sympy format.
842
  By default this uses the last equation file.
843
  """
844
  if equations is None: equations = get_hof()
845
+ if isinstance(equations, list):
846
+ return [best_row(eq)['sympy_format'].simplify() for eq in equations]
847
+ else:
848
+ return best_row(equations)['sympy_format'].simplify()
849
 
850
  def best_callable(equations=None):
851
  """Return the equation with the best score, in callable format.
852
  By default this uses the last equation file.
853
  """
854
  if equations is None: equations = get_hof()
855
+ if isinstance(equations, list):
856
+ return [best_row(eq)['lambda_format'] for eq in equations]
857
+ else:
858
+ return best_row(equations)['lambda_format']
859
 
860
  def _escape_filename(filename):
861
  """Turns a file into a string representation with correctly escaped backslashes"""
test/test.py CHANGED
@@ -17,14 +17,15 @@ equations = pysr(X, y, **default_test_kwargs)
17
  print(equations)
18
  assert equations.iloc[-1]['MSE'] < 1e-4
19
 
20
- print("Test 2 - test custom operator")
21
- y = X[:, 0]**2
22
  equations = pysr(X, y,
23
  unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
24
  extra_sympy_mappings={'square': lambda x: x**2},
25
  **default_test_kwargs)
26
  print(equations)
27
- assert equations.iloc[-1]['MSE'] < 1e-4
 
28
 
29
  X = np.random.randn(100, 1)
30
  y = X[:, 0] + 3.0
 
17
  print(equations)
18
  assert equations.iloc[-1]['MSE'] < 1e-4
19
 
20
+ print("Test 2 - test custom operator, and multiple outputs")
21
+ y = X[:, [0, 1]]**2
22
  equations = pysr(X, y,
23
  unary_operators=["sq(x) = x^2"], binary_operators=["plus"],
24
  extra_sympy_mappings={'square': lambda x: x**2},
25
  **default_test_kwargs)
26
  print(equations)
27
+ assert equations[0].iloc[-1]['MSE'] < 1e-4
28
+ assert equations[1].iloc[-1]['MSE'] < 1e-4
29
 
30
  X = np.random.randn(100, 1)
31
  y = X[:, 0] + 3.0