MilesCranmer committed on
Commit
898f500
1 Parent(s): 3772652

Add mechanism for extracting JAX functions

Browse files
Files changed (2) hide show
  1. pysr/sr.py +18 -4
  2. setup.py +1 -1
pysr/sr.py CHANGED
@@ -12,7 +12,7 @@ import shutil
12
  from pathlib import Path
13
  from datetime import datetime
14
  import warnings
15
-
16
 
17
  global_equation_file = 'hall_of_fame.csv'
18
  global_n_features = None
@@ -106,6 +106,7 @@ def pysr(X=None, y=None, weights=None,
106
  user_input=True,
107
  update=True,
108
  temp_equation_file=False,
 
109
  warmupMaxsize=None, #Deprecated
110
  ):
111
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
@@ -216,6 +217,8 @@ def pysr(X=None, y=None, weights=None,
216
  :param temp_equation_file: Whether to put the hall of fame file in
217
  the temp directory. Deletion is then controlled with the
218
  delete_tempfiles argument.
 
 
219
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
220
  (as strings).
221
 
@@ -281,7 +284,8 @@ def pysr(X=None, y=None, weights=None,
281
  weightSimplify=weightSimplify,
282
  constraints=constraints,
283
  extra_sympy_mappings=extra_sympy_mappings,
284
- julia_project=julia_project, loss=loss)
 
285
 
286
  kwargs = {**_set_paths(tempdir), **kwargs}
287
 
@@ -633,7 +637,8 @@ def run_feature_selection(X, y, select_k_features):
633
  max_features=select_k_features, prefit=True)
634
  return selector.get_support(indices=True)
635
 
636
- def get_hof(equation_file=None, n_features=None, variable_names=None, extra_sympy_mappings=None, **kwargs):
 
637
  """Get the equations from a hall of fame file. If no arguments
638
  entered, the ones used previously from a call to PySR will be used."""
639
 
@@ -663,6 +668,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
663
  lastComplexity = 0
664
  sympy_format = []
665
  lambda_format = []
 
 
666
  use_custom_variable_names = (len(variable_names) != 0)
667
  local_sympy_mappings = {
668
  **extra_sympy_mappings,
@@ -677,6 +684,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
677
  for i in range(len(output)):
678
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
679
  sympy_format.append(eqn)
 
 
 
680
  lambda_format.append(lambdify(sympy_symbols, eqn))
681
  curMSE = output.loc[i, 'MSE']
682
  curComplexity = output.loc[i, 'Complexity']
@@ -693,8 +703,12 @@ def get_hof(equation_file=None, n_features=None, variable_names=None, extra_symp
693
  output['score'] = np.array(scores)
694
  output['sympy_format'] = sympy_format
695
  output['lambda_format'] = lambda_format
 
 
 
 
696
 
697
- return output[['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']]
698
 
699
  def best_row(equations=None):
700
  """Return the best row of a hall of fame file using the score column.
 
12
  from pathlib import Path
13
  from datetime import datetime
14
  import warnings
15
+ from .export import sympy2jax
16
 
17
  global_equation_file = 'hall_of_fame.csv'
18
  global_n_features = None
 
106
  user_input=True,
107
  update=True,
108
  temp_equation_file=False,
109
+ output_jax_format=False,
110
  warmupMaxsize=None, #Deprecated
111
  ):
112
  """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
 
217
  :param temp_equation_file: Whether to put the hall of fame file in
218
  the temp directory. Deletion is then controlled with the
219
  delete_tempfiles argument.
220
+ :param output_jax_format: Whether to create a 'jax_format' column in the output,
221
+ containing jax-callable functions and the default parameters in a jax array.
222
  :returns: pd.DataFrame, Results dataframe, giving complexity, MSE, and equations
223
  (as strings).
224
 
 
284
  weightSimplify=weightSimplify,
285
  constraints=constraints,
286
  extra_sympy_mappings=extra_sympy_mappings,
287
+ julia_project=julia_project, loss=loss,
288
+ output_jax_format=output_jax_format)
289
 
290
  kwargs = {**_set_paths(tempdir), **kwargs}
291
 
 
637
  max_features=select_k_features, prefit=True)
638
  return selector.get_support(indices=True)
639
 
640
+ def get_hof(equation_file=None, n_features=None, variable_names=None,
641
+ extra_sympy_mappings=None, output_jax_format=False, **kwargs):
642
  """Get the equations from a hall of fame file. If no arguments
643
  entered, the ones used previously from a call to PySR will be used."""
644
 
 
668
  lastComplexity = 0
669
  sympy_format = []
670
  lambda_format = []
671
+ if output_jax_format:
672
+ jax_format = []
673
  use_custom_variable_names = (len(variable_names) != 0)
674
  local_sympy_mappings = {
675
  **extra_sympy_mappings,
 
684
  for i in range(len(output)):
685
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
686
  sympy_format.append(eqn)
687
+ if output_jax_format:
688
+ func, params = sympy2jax(eqn, sympy_symbols)
689
+ jax_format.append({'callable': func, 'parameters': params})
690
  lambda_format.append(lambdify(sympy_symbols, eqn))
691
  curMSE = output.loc[i, 'MSE']
692
  curComplexity = output.loc[i, 'Complexity']
 
703
  output['score'] = np.array(scores)
704
  output['sympy_format'] = sympy_format
705
  output['lambda_format'] = lambda_format
706
+ output_cols = ['Complexity', 'MSE', 'score', 'Equation', 'sympy_format', 'lambda_format']
707
+ if output_jax_format:
708
+ output_cols += ['jax_format']
709
+ output['jax_format'] = jax_format
710
 
711
+ return output[output_cols]
712
 
713
  def best_row(equations=None):
714
  """Return the best row of a hall of fame file using the score column.
setup.py CHANGED
@@ -5,7 +5,7 @@ with open("README.md", "r") as fh:
5
 
6
  setuptools.setup(
7
  name="pysr", # Replace with your own username
8
- version="0.5.12",
9
  author="Miles Cranmer",
10
  author_email="[email protected]",
11
  description="Simple and efficient symbolic regression",
 
5
 
6
  setuptools.setup(
7
  name="pysr", # Replace with your own username
8
+ version="0.5.13",
9
  author="Miles Cranmer",
10
  author_email="[email protected]",
11
  description="Simple and efficient symbolic regression",