MilesCranmer committed
Commit 68b3673
1 Parent(s): a06bfc4

Add torch format output; dont import jax/torch by default

Files changed (4)
  1. pysr/__init__.py +0 -2
  2. pysr/export_jax.py +47 -51
  3. pysr/export_torch.py +0 -2
  4. pysr/sr.py +21 -4
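
The new flags surface in the pysr() call. A minimal usage sketch based only on the parameter and column names added in this commit (output_torch_format, and the 'jax_format' / 'torch_format' columns holding dicts with 'callable' and 'parameters' keys); the toy data is illustrative and jax/torch must be installed for the respective flags:

import numpy as np
from pysr import pysr

X = np.random.randn(100, 2)
y = 2.0 * np.sin(X[:, 0]) + X[:, 1]

# Ask for the extra export columns introduced by this commit.
equations = pysr(X, y, niterations=5,
                 output_jax_format=True,
                 output_torch_format=True)

row = equations.iloc[-1]                       # highest-complexity equation
jax_fn = row['jax_format']['callable']         # jax-callable function
jax_params = row['jax_format']['parameters']   # default parameters (jax array)
torch_module = row['torch_format']['callable']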
pysr/__init__.py CHANGED
@@ -1,4 +1,2 @@
from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
from .feynman_problems import Problem, FeynmanProblem
-from .export_jax import sympy2jax
-from .export_torch import sympy2torch
 
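With the two imports above removed, sympy2jax and sympy2torch are no longer re-exported from the top-level pysr package, and jax/torch are no longer pulled in by `import pysr`. A minimal sketch of reaching the exporters directly from their submodules when needed (assumes jax and torch are installed):

from pysr.export_jax import sympy2jax      # imports jax
from pysr.export_torch import sympy2torch  # imports torch
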
pysr/export_jax.py CHANGED
@@ -2,60 +2,56 @@ import functools as ft
import sympy
import string
import random
-
-try:
-    import jax
-    from jax import numpy as jnp
-    from jax.scipy import special as jsp
+import jax
+from jax import numpy as jnp
+from jax.scipy import special as jsp

# Special since need to reduce arguments.
-    MUL = 0
-    ADD = 1
+MUL = 0
+ADD = 1

-    _jnp_func_lookup = {
-        sympy.Mul: MUL,
-        sympy.Add: ADD,
-        sympy.div: "jnp.div",
-        sympy.Abs: "jnp.abs",
-        sympy.sign: "jnp.sign",
-        # Note: May raise error for ints.
-        sympy.ceiling: "jnp.ceil",
-        sympy.floor: "jnp.floor",
-        sympy.log: "jnp.log",
-        sympy.exp: "jnp.exp",
-        sympy.sqrt: "jnp.sqrt",
-        sympy.cos: "jnp.cos",
-        sympy.acos: "jnp.acos",
-        sympy.sin: "jnp.sin",
-        sympy.asin: "jnp.asin",
-        sympy.tan: "jnp.tan",
-        sympy.atan: "jnp.atan",
-        sympy.atan2: "jnp.atan2",
-        # Note: Also may give NaN for complex results.
-        sympy.cosh: "jnp.cosh",
-        sympy.acosh: "jnp.acosh",
-        sympy.sinh: "jnp.sinh",
-        sympy.asinh: "jnp.asinh",
-        sympy.tanh: "jnp.tanh",
-        sympy.atanh: "jnp.atanh",
-        sympy.Pow: "jnp.power",
-        sympy.re: "jnp.real",
-        sympy.im: "jnp.imag",
-        sympy.arg: "jnp.angle",
-        # Note: May raise error for ints and complexes
-        sympy.erf: "jsp.erf",
-        sympy.erfc: "jsp.erfc",
-        sympy.LessThan: "jnp.less",
-        sympy.GreaterThan: "jnp.greater",
-        sympy.And: "jnp.logical_and",
-        sympy.Or: "jnp.logical_or",
-        sympy.Not: "jnp.logical_not",
-        sympy.Max: "jnp.max",
-        sympy.Min: "jnp.min",
-        sympy.Mod: "jnp.mod",
-    }
-except ImportError:
-    ...
+_jnp_func_lookup = {
+    sympy.Mul: MUL,
+    sympy.Add: ADD,
+    sympy.div: "jnp.div",
+    sympy.Abs: "jnp.abs",
+    sympy.sign: "jnp.sign",
+    # Note: May raise error for ints.
+    sympy.ceiling: "jnp.ceil",
+    sympy.floor: "jnp.floor",
+    sympy.log: "jnp.log",
+    sympy.exp: "jnp.exp",
+    sympy.sqrt: "jnp.sqrt",
+    sympy.cos: "jnp.cos",
+    sympy.acos: "jnp.acos",
+    sympy.sin: "jnp.sin",
+    sympy.asin: "jnp.asin",
+    sympy.tan: "jnp.tan",
+    sympy.atan: "jnp.atan",
+    sympy.atan2: "jnp.atan2",
+    # Note: Also may give NaN for complex results.
+    sympy.cosh: "jnp.cosh",
+    sympy.acosh: "jnp.acosh",
+    sympy.sinh: "jnp.sinh",
+    sympy.asinh: "jnp.asinh",
+    sympy.tanh: "jnp.tanh",
+    sympy.atanh: "jnp.atanh",
+    sympy.Pow: "jnp.power",
+    sympy.re: "jnp.real",
+    sympy.im: "jnp.imag",
+    sympy.arg: "jnp.angle",
+    # Note: May raise error for ints and complexes
+    sympy.erf: "jsp.erf",
+    sympy.erfc: "jsp.erfc",
+    sympy.LessThan: "jnp.less",
+    sympy.GreaterThan: "jnp.greater",
+    sympy.And: "jnp.logical_and",
+    sympy.Or: "jnp.logical_or",
+    sympy.Not: "jnp.logical_not",
+    sympy.Max: "jnp.max",
+    sympy.Min: "jnp.min",
+    sympy.Mod: "jnp.mod",
+}

def sympy2jaxtext(expr, parameters, symbols_in):
    if issubclass(expr.func, sympy.Float):
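
For context, a rough sketch of how the JAX export could be used. That sympy2jax(eqn, sympy_symbols) returns a (callable, parameters) pair comes from the pysr/sr.py hunk below; the call convention f(X, parameters), with X columns ordered like the symbol list, is an assumption not spelled out in this commit:

import numpy as np
import sympy
from pysr.export_jax import sympy2jax

x0, x1 = sympy.symbols("x0 x1")
expr = 2.0 * sympy.sin(x0) + x1

f, parameters = sympy2jax(expr, [x0, x1])  # callable plus default constants
X = np.random.randn(100, 2)
y_pred = f(X, parameters)                  # assumed call signature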
pysr/export_torch.py CHANGED
@@ -8,7 +8,6 @@ import functools as ft
import sympy
import torch

-
def _reduce(fn):
    def fn_(*args):
        return ft.reduce(fn, args)
@@ -67,7 +66,6 @@ _global_func_lookup = {
    sympy.Determinant: torch.det,
}

-
class _Node(torch.nn.Module):
    def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
        super().__init__(**kwargs)
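
The torch exporter follows the same idea as the JAX one: a lookup table (_global_func_lookup) maps sympy functions onto torch operations, and a _Node torch module is constructed from the sympy expression (with a memo dict for shared sub-expressions). The sketch below is not the export_torch.py implementation, just a simplified illustration of that lookup-driven evaluation, with a made-up helper name eval_expr:

import sympy
import torch

# Tiny stand-in for _global_func_lookup: sympy function -> torch function.
_lookup = {
    sympy.sin: torch.sin,
    sympy.exp: torch.exp,
}

def eval_expr(expr, symbols, X):
    """Recursively evaluate a sympy expression on a batch of features,
    where X[:, i] corresponds to symbols[i]. Handles only a few node types."""
    if expr.is_Symbol:
        return X[:, symbols.index(expr)]
    if expr.is_Number:
        return torch.tensor(float(expr))
    args = [eval_expr(arg, symbols, X) for arg in expr.args]
    if expr.func == sympy.Add:
        return sum(args)
    if expr.func == sympy.Mul:
        out = args[0]
        for arg in args[1:]:
            out = out * arg
        return out
    return _lookup[expr.func](*args)

x0, x1 = sympy.symbols("x0 x1")
X = torch.randn(100, 2)
y_pred = eval_expr(2.0 * sympy.sin(x0) + x1, [x0, x1], X)

The module-based export additionally exposes the numeric constants as trainable parameters, which is what the 'torch_format' docstring below refers to.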
pysr/sr.py CHANGED
@@ -13,8 +13,6 @@ import shutil
from pathlib import Path
from datetime import datetime
import warnings
-from .export_jax import sympy2jax
-from .export_torch import sympy2torch

global_equation_file = 'hall_of_fame.csv'
global_n_features = None
@@ -125,11 +123,12 @@ def pysr(X, y, weights=None,
        update=True,
        temp_equation_file=False,
        output_jax_format=False,
+        output_torch_format=False,
        optimizer_algorithm="BFGS",
        optimizer_nrestarts=3,
        optimize_probability=1.0,
-        optimizer_iterations=10,
-        ):
+        optimizer_iterations=10
+        ):
    """Run symbolic regression to fit f(X[i, :]) ~ y[i] for all i.
    Note: most default parameters have been tuned over several example
    equations, but you should adjust `niterations`,
@@ -242,6 +241,8 @@ def pysr(X, y, weights=None,
        delete_tempfiles argument.
    :param output_jax_format: Whether to create a 'jax_format' column in the output,
        containing jax-callable functions and the default parameters in a jax array.
+    :param output_torch_format: Whether to create a 'torch_format' column in the output,
+        containing a torch module with trainable parameters.
    :returns: pd.DataFrame or list, Results dataframe,
        giving complexity, MSE, and equations (as strings), as well as functional
        forms. If list, each element corresponds to a dataframe of equations
@@ -337,6 +338,7 @@ def pysr(X, y, weights=None,
        extra_sympy_mappings=extra_sympy_mappings,
        julia_project=julia_project, loss=loss,
        output_jax_format=output_jax_format,
+        output_torch_format=output_torch_format,
        multioutput=multioutput, nout=nout)

    kwargs = {**_set_paths(tempdir), **kwargs}
@@ -727,6 +729,7 @@ def run_feature_selection(X, y, select_k_features):

def get_hof(equation_file=None, n_features=None, variable_names=None,
            extra_sympy_mappings=None, output_jax_format=False,
+            output_torch_format=False,
            multioutput=None, nout=None, **kwargs):
    """Get the equations from a hall of fame file. If no arguments
    entered, the ones used previously from a call to PySR will be used."""
@@ -771,6 +774,8 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
    lambda_format = []
    if output_jax_format:
        jax_format = []
+    if output_torch_format:
+        torch_format = []
    use_custom_variable_names = (len(variable_names) != 0)
    local_sympy_mappings = {
        **extra_sympy_mappings,
@@ -786,10 +791,19 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
        eqn = sympify(output.loc[i, 'Equation'], locals=local_sympy_mappings)
        sympy_format.append(eqn)
        if output_jax_format:
+            from .export_jax import sympy2jax
            func, params = sympy2jax(eqn, sympy_symbols)
            jax_format.append({'callable': func, 'parameters': params})
+<<<<<<< HEAD

        lambda_format.append(CallableEquation(sympy_symbols, eqn))
+=======
+        if output_torch_format:
+            from .export_torch import sympy2torch
+            func, params = sympy2torch(eqn, sympy_symbols)
+            torch_format.append({'callable': func, 'parameters': params})
+        lambda_format.append(lambdify(sympy_symbols, eqn))
+>>>>>>> 6ba697f (Add torch format output; dont import jax/torch by default)
        curMSE = output.loc[i, 'MSE']
        curComplexity = output.loc[i, 'Complexity']

@@ -809,6 +823,9 @@ def get_hof(equation_file=None, n_features=None, variable_names=None,
        if output_jax_format:
            output_cols += ['jax_format']
            output['jax_format'] = jax_format
+        if output_torch_format:
+            output_cols += ['torch_format']
+            output['torch_format'] = torch_format

        ret_outputs.append(output[output_cols])

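
The sr.py hunks above also move the exporter imports inside the branches that need them, so jax and torch are only imported when the corresponding output format is requested. Condensed into a standalone sketch (the helper name export_equation is hypothetical; the dict keys mirror the diff):

def export_equation(eqn, sympy_symbols,
                    output_jax_format=False, output_torch_format=False):
    """Build the per-equation export entries, importing heavy backends lazily."""
    exports = {}
    if output_jax_format:
        from pysr.export_jax import sympy2jax      # deferred: jax only needed here
        func, params = sympy2jax(eqn, sympy_symbols)
        exports['jax_format'] = {'callable': func, 'parameters': params}
    if output_torch_format:
        from pysr.export_torch import sympy2torch  # deferred: torch only needed here
        func, params = sympy2torch(eqn, sympy_symbols)
        exports['torch_format'] = {'callable': func, 'parameters': params}
    return exports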