Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
84e4a47
1
Parent(s):
e7ede78
Allow user to pass extra torch operators to pysr
Browse files- pysr/sr.py +12 -5
pysr/sr.py
CHANGED
@@ -102,6 +102,8 @@ def pysr(X, y, weights=None,
|
|
102 |
perturbationFactor=1.0,
|
103 |
timeout=None,
|
104 |
extra_sympy_mappings=None,
|
|
|
|
|
105 |
equation_file=None,
|
106 |
verbosity=1e9,
|
107 |
progress=True,
|
@@ -336,6 +338,8 @@ def pysr(X, y, weights=None,
|
|
336 |
weightSimplify=weightSimplify,
|
337 |
constraints=constraints,
|
338 |
extra_sympy_mappings=extra_sympy_mappings,
|
|
|
|
|
339 |
julia_project=julia_project, loss=loss,
|
340 |
output_jax_format=output_jax_format,
|
341 |
output_torch_format=output_torch_format,
|
@@ -730,6 +734,7 @@ def run_feature_selection(X, y, select_k_features):
|
|
730 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
731 |
extra_sympy_mappings=None, output_jax_format=False,
|
732 |
output_torch_format=False,
|
|
|
733 |
multioutput=None, nout=None, **kwargs):
|
734 |
"""Get the equations from a hall of fame file. If no arguments
|
735 |
entered, the ones used previously from a call to PySR will be used."""
|
@@ -790,20 +795,22 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
|
|
790 |
for i in range(len(output)):
|
791 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
792 |
sympy_format.append(eqn)
|
|
|
|
|
|
|
|
|
|
|
793 |
if output_jax_format:
|
794 |
from .export_jax import sympy2jax
|
795 |
func, params = sympy2jax(eqn, sympy_symbols)
|
796 |
jax_format.append({'callable': func, 'parameters': params})
|
797 |
-
<<<<<<< HEAD
|
798 |
|
799 |
-
|
800 |
-
=======
|
801 |
if output_torch_format:
|
802 |
from .export_torch import sympy2torch
|
803 |
module = sympy2torch(eqn, sympy_symbols)
|
804 |
torch_format.append(module)
|
805 |
-
|
806 |
-
>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
|
807 |
curMSE = output.loc[i, 'MSE']
|
808 |
curComplexity = output.loc[i, 'Complexity']
|
809 |
|
|
|
102 |
perturbationFactor=1.0,
|
103 |
timeout=None,
|
104 |
extra_sympy_mappings=None,
|
105 |
+
extra_torch_mappings=None,
|
106 |
+
extra_jax_mappings=None,
|
107 |
equation_file=None,
|
108 |
verbosity=1e9,
|
109 |
progress=True,
|
|
|
338 |
weightSimplify=weightSimplify,
|
339 |
constraints=constraints,
|
340 |
extra_sympy_mappings=extra_sympy_mappings,
|
341 |
+
extra_jax_mappings=extra_jax_mappings,
|
342 |
+
extra_torch_mappings=extra_torch_mappings,
|
343 |
julia_project=julia_project, loss=loss,
|
344 |
output_jax_format=output_jax_format,
|
345 |
output_torch_format=output_torch_format,
|
|
|
734 |
def get_hof(equation_file=None, n_features=None, variable_names=None,
|
735 |
extra_sympy_mappings=None, output_jax_format=False,
|
736 |
output_torch_format=False,
|
737 |
+
extra_jax_mappings=None, extra_torch_mappings=None,
|
738 |
multioutput=None, nout=None, **kwargs):
|
739 |
"""Get the equations from a hall of fame file. If no arguments
|
740 |
entered, the ones used previously from a call to PySR will be used."""
|
|
|
795 |
for i in range(len(output)):
|
796 |
eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
|
797 |
sympy_format.append(eqn)
|
798 |
+
|
799 |
+
# Numpy:
|
800 |
+
lambda_format.append(CallableEquation(sympy_symbols, eqn))
|
801 |
+
|
802 |
+
# JAX:
|
803 |
if output_jax_format:
|
804 |
from .export_jax import sympy2jax
|
805 |
func, params = sympy2jax(eqn, sympy_symbols)
|
806 |
jax_format.append({'callable': func, 'parameters': params})
|
|
|
807 |
|
808 |
+
# Torch:
|
|
|
809 |
if output_torch_format:
|
810 |
from .export_torch import sympy2torch
|
811 |
module = sympy2torch(eqn, sympy_symbols)
|
812 |
torch_format.append(module)
|
813 |
+
|
|
|
814 |
curMSE = output.loc[i, 'MSE']
|
815 |
curComplexity = output.loc[i, 'Complexity']
|
816 |
|