MilesCranmer commited on
Commit
5b978f9
β€’
1 Parent(s): 6a4fa2c

Move JAX export to separate file

Browse files
pysr/__init__.py CHANGED
@@ -1,3 +1,3 @@
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
3
- from .export import sympy2jax
 
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
3
+ from .export_jax import sympy2jax
pysr/{export.py β†’ export_jax.py} RENAMED
@@ -75,7 +75,7 @@ def sympy2jaxtext(expr, parameters, symbols_in):
75
  else:
76
  return f'{_func}({", ".join(args)})'
77
 
78
- def sympy2jax(equation, symbols_in):
79
  """Returns a function f and its parameters;
80
  the function takes an input matrix, and a list of arguments:
81
  f(X, parameters)
@@ -147,8 +147,8 @@ def sympy2jax(equation, symbols_in):
147
  ```
148
  """
149
  parameters = []
150
- functional_form_text = sympy2jaxtext(equation, parameters, symbols_in)
151
- hash_string = 'A_' + str(abs(hash(str(equation) + str(symbols_in))))
152
  text = f"def {hash_string}(X, parameters):\n"
153
  text += " return "
154
  text += functional_form_text
 
75
  else:
76
  return f'{_func}({", ".join(args)})'
77
 
78
+ def sympy2jax(expression, symbols_in):
79
  """Returns a function f and its parameters;
80
  the function takes an input matrix, and a list of arguments:
81
  f(X, parameters)
 
147
  ```
148
  """
149
  parameters = []
150
+ functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
151
+ hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
152
  text = f"def {hash_string}(X, parameters):\n"
153
  text += " return "
154
  text += functional_form_text
pysr/sr.py CHANGED
@@ -13,7 +13,7 @@ import shutil
13
  from pathlib import Path
14
  from datetime import datetime
15
  import warnings
16
- from .export import sympy2jax
17
 
18
  global_equation_file = 'hall_of_fame.csv'
19
  global_n_features = None
 
13
  from pathlib import Path
14
  from datetime import datetime
15
  import warnings
16
+ from .export_jax import sympy2jax
17
 
18
  global_equation_file = 'hall_of_fame.csv'
19
  global_n_features = None