MilesCranmer committed on
Commit
b5b74c3
1 Parent(s): a9c6697

[WIP] begin implementation of multi-output compatibility

Browse files
Files changed (1) hide show
  1. pysr/sr.py +94 -39
pysr/sr.py CHANGED
@@ -19,6 +19,8 @@ global_equation_file = 'hall_of_fame.csv'
19
  global_n_features = None
20
  global_variable_names = []
21
  global_extra_sympy_mappings = {}
 
 
22
 
23
  sympy_mappings = {
24
  'div': lambda x, y : x/y,
@@ -276,6 +278,15 @@ def pysr(X=None, y=None, weights=None,
276
  if X is None:
277
  X, y = _using_test_input(X, test, y)
278
 
 
 
 
 
 
 
 
 
 
279
  kwargs = dict(X=X, y=y, weights=weights,
280
  alpha=alpha, annealing=annealing, batchSize=batchSize,
281
  batching=batching, binary_operators=binary_operators,
@@ -309,7 +320,8 @@ def pysr(X=None, y=None, weights=None,
309
  constraints=constraints,
310
  extra_sympy_mappings=extra_sympy_mappings,
311
  julia_project=julia_project, loss=loss,
312
- output_jax_format=output_jax_format)
 
313
 
314
  kwargs = {**_set_paths(tempdir), **kwargs}
315
 
@@ -358,15 +370,20 @@ def pysr(X=None, y=None, weights=None,
358
 
359
 
360
 
361
- def _set_globals(X, equation_file, extra_sympy_mappings, variable_names, **kwargs):
 
362
  global global_n_features
363
  global global_equation_file
364
  global global_variable_names
365
  global global_extra_sympy_mappings
 
 
366
  global_n_features = X.shape[1]
367
  global_equation_file = equation_file
368
  global_variable_names = variable_names
369
  global_extra_sympy_mappings = extra_sympy_mappings
 
 
370
 
371
 
372
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
@@ -438,17 +455,35 @@ def _create_julia_files(dataset_filename, def_datasets, hyperparam_filename, de
438
  print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
439
 
440
 
441
- def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename, **kwargs):
 
442
  def_datasets = """using DelimitedFiles"""
443
  np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
444
- np.savetxt(y_filename, y.reshape(-1, 1).astype(np.float32), delimiter=',')
 
 
 
445
  if weights is not None:
446
- np.savetxt(weights_filename, weights.reshape(-1, 1), delimiter=',')
 
 
 
447
  def_datasets += f"""
448
- X = copy(transpose(readdlm("{_escape_filename(X_filename)}", ',', Float32, '\\n')))
 
 
 
 
 
 
449
  y = readdlm("{_escape_filename(y_filename)}", ',', Float32, '\\n')[:, 1]"""
 
450
  if weights is not None:
451
- def_datasets += f"""
 
 
 
 
452
  weights = readdlm("{_escape_filename(weights_filename)}", ',', Float32, '\\n')[:, 1]"""
453
  return def_datasets
454
 
@@ -656,10 +691,10 @@ def _check_assertions(X, binary_operators, unary_operators, use_custom_variable_
656
  # Check for potential errors before they happen
657
  assert len(unary_operators) + len(binary_operators) > 0
658
  assert len(X.shape) == 2
659
- assert len(y.shape) == 1
660
  assert X.shape[0] == y.shape[0]
661
  if weights is not None:
662
- assert len(weights.shape) == 1
663
  assert X.shape[0] == weights.shape[0]
664
  if use_custom_variable_names:
665
  assert len(variable_names) == X.shape[1]
@@ -693,7 +728,8 @@ def run_feature_selection(X, y, select_k_features):
693
  return selector.get_support(indices=True)
694
 
695
  def get_hof(equation_file=None, n_features=None, variable_names=None,
696
- extra_sympy_mappings=None, output_jax_format=False, **kwargs):
 
697
  """Get the equations from a hall of fame file. If no arguments
698
  entered, the ones used previously from a call to PySR will be used."""
699
 
@@ -701,19 +737,28 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
701
  global global_equation_file
702
  global global_variable_names
703
  global global_extra_sympy_mappings
 
 
704
 
705
  if equation_file is None: equation_file = global_equation_file
706
  if n_features is None: n_features = global_n_features
707
  if variable_names is None: variable_names = global_variable_names
708
  if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
 
 
709
 
710
  global_equation_file = equation_file
711
  global_n_features = n_features
712
  global_variable_names = variable_names
713
  global_extra_sympy_mappings = extra_sympy_mappings
 
 
714
 
715
  try:
716
- output = pd.read_csv(str(equation_file) + '.bkup', sep="|")
 
 
 
717
  except FileNotFoundError:
718
  raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
719
 
@@ -735,42 +780,52 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
735
  else:
736
  sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
737
 
738
- for i in range(len(output)):
739
- eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
740
- sympy_format.append(eqn)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
741
  if output_jax_format:
742
- func, params = sympy2jax(eqn, sympy_symbols)
743
- jax_format.append({'callable': func, 'parameters': params})
744
- lambda_format.append(lambdify(sympy_symbols, eqn))
745
- curMSE = output.loc[i, 'MSE']
746
- curComplexity = output.loc[i, 'Complexity']
747
-
748
- if lastMSE is None:
749
- cur_score = 0.0
750
- else:
751
- cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
752
 
753
- scores.append(cur_score)
754
- lastMSE = curMSE
755
- lastComplexity = curComplexity
756
-
757
- output['score'] = np.array(scores)
758
- output['sympy_format'] = sympy_format
759
- output['lambda_format'] = lambda_format
760
- output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
761
- if output_jax_format:
762
- output_cols += ['jax_format']
763
- output['jax_format'] = jax_format
764
 
765
- return output[output_cols]
 
 
 
766
 
767
  def best_row(equations=None):
768
  """Return the best row of a hall of fame file using the score column.
769
  By default this uses the last equation file.
770
  """
771
- if equations is None: equations = get_hof()
772
- best_idx = np.argmax(equations['score'])
773
- return equations.iloc[best_idx]
 
 
 
774
 
775
  def best_tex(equations=None):
776
  """Return the equation with the best score, in latex format
 
19
  global_n_features = None
20
  global_variable_names = []
21
  global_extra_sympy_mappings = {}
22
+ global_multioutput = False
23
+ global_nout = 1
24
 
25
  sympy_mappings = {
26
  'div': lambda x, y : x/y,
 
278
  if X is None:
279
  X, y = _using_test_input(X, test, y)
280
 
281
+ if len(y.shape) == 2:
282
+ multioutput = True
283
+ nout = y.shape[1]
284
+ elif len(y.shape) == 1:
285
+ multioutput = False
286
+ nout = 1
287
+ else:
288
+ raise NotImplementedError("y shape not supported!")
289
+
290
  kwargs = dict(X=X, y=y, weights=weights,
291
  alpha=alpha, annealing=annealing, batchSize=batchSize,
292
  batching=batching, binary_operators=binary_operators,
 
320
  constraints=constraints,
321
  extra_sympy_mappings=extra_sympy_mappings,
322
  julia_project=julia_project, loss=loss,
323
+ output_jax_format=output_jax_format,
324
+ multioutput=multioutput, nout=nout)
325
 
326
  kwargs = {**_set_paths(tempdir), **kwargs}
327
 
 
370
 
371
 
372
 
373
def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
                 multioutput, nout=1, **kwargs):
    """Record the configuration of the current run in module-level globals.

    ``get_hof`` falls back to these values when it is called without
    explicit arguments, so they must reflect the most recent ``pysr`` call.

    Args:
        X: 2-D feature array; only ``X.shape[1]`` (the feature count) is kept.
        equation_file: Path of the hall-of-fame file for this run.
        extra_sympy_mappings: Extra name -> callable mappings for sympy.
        variable_names: Custom feature names.
        multioutput: Whether ``y`` has more than one output column.
        nout: Number of outputs. It used to be swallowed by ``**kwargs``,
            which made the ``global_nout`` assignment raise ``NameError``;
            it is now an explicit parameter (default 1 keeps old call sites
            working).
        **kwargs: Ignored; lets callers pass their full option dict.
    """
    global global_n_features
    global global_equation_file
    global global_variable_names
    global global_extra_sympy_mappings
    global global_multioutput
    global global_nout
    global_n_features = X.shape[1]
    global_equation_file = equation_file
    global_variable_names = variable_names
    global_extra_sympy_mappings = extra_sympy_mappings
    global_multioutput = multioutput
    global_nout = nout
387
 
388
 
389
  def _final_pysr_process(julia_optimization, runfile_filename, timeout, **kwargs):
 
455
  print(f'EquationSearch(X, y, niterations={niterations:d}, varMap={varMap}, options=options, numprocs={procs})', file=f)
456
 
457
 
458
def _make_datasets_julia_str(X, X_filename, weights, weights_filename, y, y_filename,
                             multioutput, **kwargs):
    """Dump the datasets to CSV and return Julia code that reads them back.

    ``X``, ``y`` and (when given) ``weights`` are written as float32,
    comma-delimited text. In single-output mode ``y`` and ``weights`` are
    reshaped to one column before saving and read back as the first column;
    in multi-output mode they are saved as-is and read back transposed,
    the same way ``X`` is.

    Returns:
        The Julia source snippet, one statement per line.
    """
    has_weights = weights is not None

    # --- write the arrays to disk -------------------------------------
    np.savetxt(X_filename, X.astype(np.float32), delimiter=',')
    y_to_save = y if multioutput else y.reshape(-1, 1)
    np.savetxt(y_filename, y_to_save.astype(np.float32), delimiter=',')
    if has_weights:
        w_to_save = weights if multioutput else weights.reshape(-1, 1)
        np.savetxt(weights_filename, w_to_save.astype(np.float32), delimiter=',')

    # --- build the matching Julia loader statements -------------------
    x_path = _escape_filename(X_filename)
    y_path = _escape_filename(y_filename)
    statements = [
        """using DelimitedFiles""",
        f"""X = copy(transpose(readdlm("{x_path}", ',', Float32, '\\n')))""",
    ]
    if multioutput:
        statements.append(
            f"""y = copy(transpose(readdlm("{y_path}", ',', Float32, '\\n')))""")
    else:
        statements.append(
            f"""y = readdlm("{y_path}", ',', Float32, '\\n')[:, 1]""")

    if has_weights:
        w_path = _escape_filename(weights_filename)
        if multioutput:
            statements.append(
                f"""weights = copy(transpose(readdlm("{w_path}", ',', Float32, '\\n')))""")
        else:
            statements.append(
                f"""weights = readdlm("{w_path}", ',', Float32, '\\n')[:, 1]""")

    return "\n".join(statements)
489
 
 
691
  # Check for potential errors before they happen
692
  assert len(unary_operators) + len(binary_operators) > 0
693
  assert len(X.shape) == 2
694
+ assert len(y.shape) in [1, 2]
695
  assert X.shape[0] == y.shape[0]
696
  if weights is not None:
697
+ assert weights.shape == y.shape
698
  assert X.shape[0] == weights.shape[0]
699
  if use_custom_variable_names:
700
  assert len(variable_names) == X.shape[1]
 
728
  return selector.get_support(indices=True)
729
 
730
  def get_hof(equation_file=None, n_features=None, variable_names=None,
731
+ extra_sympy_mappings=None, output_jax_format=False,
732
+ multioutput=False, nout=1, **kwargs):
733
  """Get the equations from a hall of fame file. If no arguments
734
  entered, the ones used previously from a call to PySR will be used."""
735
 
 
737
  global global_equation_file
738
  global global_variable_names
739
  global global_extra_sympy_mappings
740
+ global global_multioutput
741
+ global global_nout
742
 
743
  if equation_file is None: equation_file = global_equation_file
744
  if n_features is None: n_features = global_n_features
745
  if variable_names is None: variable_names = global_variable_names
746
  if extra_sympy_mappings is None: extra_sympy_mappings = global_extra_sympy_mappings
747
+ if multioutput is None: multioutput = global_multioutput
748
+ if nout is None: nout = global_nout
749
 
750
  global_equation_file = equation_file
751
  global_n_features = n_features
752
  global_variable_names = variable_names
753
  global_extra_sympy_mappings = extra_sympy_mappings
754
+ global_multioutput = multioutput
755
+ global_nout = nout
756
 
757
  try:
758
+ if multioutput:
759
+ all_outputs = [pd.read_csv(f'out{i}_' + str(equation_file) + '.bkup', sep="|") for i in range(1, nout+1)]
760
+ else:
761
+ all_outputs = [pd.read_csv(str(equation_file) + '.bkup', sep="|")]
762
  except FileNotFoundError:
763
  raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
764
 
 
780
  else:
781
  sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
782
 
783
+ ret_outputs = []
784
+ for output in all_outputs:
785
+ for i in range(len(output)):
786
+ eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
787
+ sympy_format.append(eqn)
788
+ if output_jax_format:
789
+ func, params = sympy2jax(eqn, sympy_symbols)
790
+ jax_format.append({'callable': func, 'parameters': params})
791
+ lambda_format.append(lambdify(sympy_symbols, eqn))
792
+ curMSE = output.loc[i, 'MSE']
793
+ curComplexity = output.loc[i, 'Complexity']
794
+
795
+ if lastMSE is None:
796
+ cur_score = 0.0
797
+ else:
798
+ cur_score = - np.log(curMSE/lastMSE)/(curComplexity - lastComplexity)
799
+
800
+ scores.append(cur_score)
801
+ lastMSE = curMSE
802
+ lastComplexity = curComplexity
803
+
804
+ output['score'] = np.array(scores)
805
+ output['sympy_format'] = sympy_format
806
+ output['lambda_format'] = lambda_format
807
+ output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
808
  if output_jax_format:
809
+ output_cols += ['jax_format']
810
+ output['jax_format'] = jax_format
 
 
 
 
 
 
 
 
811
 
812
+ ret_outputs.append(output[output_cols])
 
 
 
 
 
 
 
 
 
 
813
 
814
+ if multioutput:
815
+ return ret_outputs
816
+ else:
817
+ return ret_outputs[0]
818
 
819
def best_row(equations=None):
    """Return the best row of a hall of fame file using the score column.
    By default this uses the last equation file.

    Args:
        equations: A hall-of-fame DataFrame, or a list of them (one per
            output, as ``get_hof`` returns in multi-output mode). When
            ``None``, the result of ``get_hof()`` is used.

    Returns:
        The highest-``score`` row, or a list of such rows (one per output)
        when ``equations`` is a list.
    """
    # Normalise into a single variable: the previous code bound the fetched
    # tables to a different name than the one it indexed, so both the default
    # path and the explicit-argument path failed.
    if equations is None:
        equations = get_hof()

    def _pick(table):
        # Positional argmax so the result is valid for ``iloc``.
        return table.iloc[np.argmax(np.asarray(table['score']))]

    if isinstance(equations, list):
        return [_pick(table) for table in equations]
    return _pick(equations)
829
 
830
  def best_tex(equations=None):
831
  """Return the equation with the best score, in latex format