MilesCranmer commited on
Commit
84e4a47
1 Parent(s): e7ede78

Allow user to pass extra torch operators to pysr

Browse files
Files changed (1) hide show
  1. pysr/sr.py +12 -5
pysr/sr.py CHANGED
@@ -102,6 +102,8 @@ def pysr(X, y, weights=None,
102
  perturbationFactor=1.0,
103
  timeout=None,
104
  extra_sympy_mappings=None,
 
 
105
  equation_file=None,
106
  verbosity=1e9,
107
  progress=True,
@@ -336,6 +338,8 @@ def pysr(X, y, weights=None,
336
  weightSimplify=weightSimplify,
337
  constraints=constraints,
338
  extra_sympy_mappings=extra_sympy_mappings,
 
 
339
  julia_project=julia_project, loss=loss,
340
  output_jax_format=output_jax_format,
341
  output_torch_format=output_torch_format,
@@ -730,6 +734,7 @@ def run_feature_selection(X, y, select_k_features):
730
  def get_hof(equation_file=None, n_features=None, variable_names=None,
731
  extra_sympy_mappings=None, output_jax_format=False,
732
  output_torch_format=False,
 
733
  multioutput=None, nout=None, **kwargs):
734
  """Get the equations from a hall of fame file. If no arguments
735
  entered, the ones used previously from a call to PySR will be used."""
@@ -790,20 +795,22 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
790
  for i in range(len(output)):
791
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
792
  sympy_format.append(eqn)
 
 
 
 
 
793
  if output_jax_format:
794
  from .export_jax import sympy2jax
795
  func, params = sympy2jax(eqn, sympy_symbols)
796
  jax_format.append({'callable': func, 'parameters': params})
797
- <<<<<<< HEAD
798
 
799
- lambda_format.append(CallableEquation(sympy_symbols, eqn))
800
- =======
801
  if output_torch_format:
802
  from .export_torch import sympy2torch
803
  module = sympy2torch(eqn, sympy_symbols)
804
  torch_format.append(module)
805
- lambda_format.append(lambdify(sympy_symbols, eqn))
806
- >>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
807
  curMSE = output.loc[i, 'MSE']
808
  curComplexity = output.loc[i, 'Complexity']
809
 
 
102
  perturbationFactor=1.0,
103
  timeout=None,
104
  extra_sympy_mappings=None,
105
+ extra_torch_mappings=None,
106
+ extra_jax_mappings=None,
107
  equation_file=None,
108
  verbosity=1e9,
109
  progress=True,
 
338
  weightSimplify=weightSimplify,
339
  constraints=constraints,
340
  extra_sympy_mappings=extra_sympy_mappings,
341
+ extra_jax_mappings=extra_jax_mappings,
342
+ extra_torch_mappings=extra_torch_mappings,
343
  julia_project=julia_project, loss=loss,
344
  output_jax_format=output_jax_format,
345
  output_torch_format=output_torch_format,
 
734
  def get_hof(equation_file=None, n_features=None, variable_names=None,
735
  extra_sympy_mappings=None, output_jax_format=False,
736
  output_torch_format=False,
737
+ extra_jax_mappings=None, extra_torch_mappings=None,
738
  multioutput=None, nout=None, **kwargs):
739
  """Get the equations from a hall of fame file. If no arguments
740
  entered, the ones used previously from a call to PySR will be used."""
 
795
  for i in range(len(output)):
796
  eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
797
  sympy_format.append(eqn)
798
+
799
+ # Numpy:
800
+ lambda_format.append(CallableEquation(sympy_symbols, eqn))
801
+
802
+ # JAX:
803
  if output_jax_format:
804
  from .export_jax import sympy2jax
805
  func, params = sympy2jax(eqn, sympy_symbols)
806
  jax_format.append({'callable': func, 'parameters': params})
 
807
 
808
+ # Torch:
 
809
  if output_torch_format:
810
  from .export_torch import sympy2torch
811
  module = sympy2torch(eqn, sympy_symbols)
812
  torch_format.append(module)
813
+
 
814
  curMSE = output.loc[i, 'MSE']
815
  curComplexity = output.loc[i, 'Complexity']
816