MilesCranmer commited on
Commit
d3b73f7
1 Parent(s): 7847c48

Fix multi-output scoring

Browse files
Files changed (1) hide show
  1. pysr/sr.py +24 -20
pysr/sr.py CHANGED
@@ -372,7 +372,7 @@ def pysr(X=None, y=None, weights=None,
372
 
373
 
374
  def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
375
- multioutput, **kwargs):
376
  global global_n_features
377
  global global_equation_file
378
  global global_variable_names
@@ -730,7 +730,7 @@ def run_feature_selection(X, y, select_k_features):
730
 
731
  def get_hof(equation_file=None, n_features=None, variable_names=None,
732
  extra_sympy_mappings=None, output_jax_format=False,
733
- multioutput=False, nout=1, **kwargs):
734
  """Get the equations from a hall of fame file. If no arguments
735
  entered, the ones used previously from a call to PySR will be used."""
736
 
@@ -763,26 +763,28 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
763
  except FileNotFoundError:
764
  raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
765
 
766
- scores = []
767
- lastMSE = None
768
- lastComplexity = 0
769
- sympy_format = []
770
- lambda_format = []
771
- if output_jax_format:
772
- jax_format = []
773
- use_custom_variable_names = (len(variable_names) != 0)
774
- local_sympy_mappings = {
775
- **extra_sympy_mappings,
776
- **sympy_mappings
777
- }
778
-
779
- if use_custom_variable_names:
780
- sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
781
- else:
782
- sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
783
-
784
  ret_outputs = []
 
785
  for output in all_outputs:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
786
  for i in range(len(output)):
787
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
788
  sympy_format.append(eqn)
@@ -842,6 +844,7 @@ def best(equations=None):
842
  By default this uses the last equation file.
843
  """
844
  if equations is None: equations = get_hof()
 
845
  return [best_row(eq)['sympy_format'].simplify() for eq in equations]
846
  else:
847
  return best_row(equations)['sympy_format'].simplify()
@@ -851,6 +854,7 @@ def best_callable(equations=None):
851
  By default this uses the last equation file.
852
  """
853
  if equations is None: equations = get_hof()
 
854
  return [best_row(eq)['lambda_format'] for eq in equations]
855
  else:
856
  return best_row(equations)['lambda_format']
 
372
 
373
 
374
  def _set_globals(X, equation_file, extra_sympy_mappings, variable_names,
375
+ multioutput, nout, **kwargs):
376
  global global_n_features
377
  global global_equation_file
378
  global global_variable_names
 
730
 
731
  def get_hof(equation_file=None, n_features=None, variable_names=None,
732
  extra_sympy_mappings=None, output_jax_format=False,
733
+ multioutput=False, nout=None, **kwargs):
734
  """Get the equations from a hall of fame file. If no arguments
735
  entered, the ones used previously from a call to PySR will be used."""
736
 
 
763
  except FileNotFoundError:
764
  raise RuntimeError("Couldn't find equation file! The equation search likely exited before a single iteration completed.")
765
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
766
  ret_outputs = []
767
+
768
  for output in all_outputs:
769
+
770
+ scores = []
771
+ lastMSE = None
772
+ lastComplexity = 0
773
+ sympy_format = []
774
+ lambda_format = []
775
+ if output_jax_format:
776
+ jax_format = []
777
+ use_custom_variable_names = (len(variable_names) != 0)
778
+ local_sympy_mappings = {
779
+ **extra_sympy_mappings,
780
+ **sympy_mappings
781
+ }
782
+
783
+ if use_custom_variable_names:
784
+ sympy_symbols = [sympy.Symbol(variable_names[i]) for i in range(n_features)]
785
+ else:
786
+ sympy_symbols = [sympy.Symbol('x%d'%i) for i in range(n_features)]
787
+
788
  for i in range(len(output)):
789
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
790
  sympy_format.append(eqn)
 
844
  By default this uses the last equation file.
845
  """
846
  if equations is None: equations = get_hof()
847
+ if isinstance(equations, list):
848
  return [best_row(eq)['sympy_format'].simplify() for eq in equations]
849
  else:
850
  return best_row(equations)['sympy_format'].simplify()
 
854
  By default this uses the last equation file.
855
  """
856
  if equations is None: equations = get_hof()
857
+ if isinstance(equations, list):
858
  return [best_row(eq)['lambda_format'] for eq in equations]
859
  else:
860
  return best_row(equations)['lambda_format']