MilesCranmer commited on
Commit
e0c68fc
1 Parent(s): a29e818

Propagate and check torch/jax mappings

Browse files
Files changed (2) hide show
  1. pysr/export_jax.py +14 -5
  2. pysr/sr.py +26 -2
pysr/export_jax.py CHANGED
@@ -51,7 +51,7 @@ _jnp_func_lookup = {
51
  }
52
 
53
 
54
- def sympy2jaxtext(expr, parameters, symbols_in):
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
@@ -61,8 +61,15 @@ def sympy2jaxtext(expr, parameters, symbols_in):
61
  return (
62
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
63
  )
64
- _func = _jnp_func_lookup[expr.func]
65
- args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
 
 
 
 
 
 
 
66
  if _func == MUL:
67
  return " * ".join(["(" + arg + ")" for arg in args])
68
  if _func == ADD:
@@ -92,7 +99,7 @@ def _initialize_jax():
92
  jsp = _jsp
93
 
94
 
95
- def sympy2jax(expression, symbols_in, selection=None):
96
  """Returns a function f and its parameters;
97
  the function takes an input matrix, and a list of arguments:
98
  f(X, parameters)
@@ -170,7 +177,9 @@ def sympy2jax(expression, symbols_in, selection=None):
170
  global jsp
171
 
172
  parameters = []
173
- functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
 
 
174
  hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
175
  text = f"def {hash_string}(X, parameters):\n"
176
  if selection is not None:
 
51
  }
52
 
53
 
54
+ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
55
  if issubclass(expr.func, sympy.Float):
56
  parameters.append(float(expr))
57
  return f"parameters[{len(parameters) - 1}]"
 
61
  return (
62
  f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
63
  )
64
+ if extra_jax_mappings is None:
65
+ extra_jax_mappings = {}
66
+ _func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
67
+ args = [
68
+ sympy2jaxtext(
69
+ arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
70
+ )
71
+ for arg in expr.args
72
+ ]
73
  if _func == MUL:
74
  return " * ".join(["(" + arg + ")" for arg in args])
75
  if _func == ADD:
 
99
  jsp = _jsp
100
 
101
 
102
+ def sympy2jax(expression, symbols_in, selection=None, extra_jax_mappings=None):
103
  """Returns a function f and its parameters;
104
  the function takes an input matrix, and a list of arguments:
105
  f(X, parameters)
 
177
  global jsp
178
 
179
  parameters = []
180
+ functional_form_text = sympy2jaxtext(
181
+ expression, parameters, symbols_in, extra_jax_mappings
182
+ )
183
  hash_string = "A_" + str(abs(hash(str(expression) + str(symbols_in))))
184
  text = f"def {hash_string}(X, parameters):\n"
185
  if selection is not None:
pysr/sr.py CHANGED
@@ -289,6 +289,20 @@ def pysr(
289
  if len(variable_names) == 0:
290
  variable_names = [f"x{i}" for i in range(X.shape[1])]
291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
292
  use_custom_variable_names = len(variable_names) != 0
293
 
294
  _check_assertions(
@@ -996,14 +1010,24 @@ def get_hof(
996
  if output_jax_format:
997
  from .export_jax import sympy2jax
998
 
999
- func, params = sympy2jax(eqn, sympy_symbols, selection)
 
 
 
 
 
1000
  jax_format.append({"callable": func, "parameters": params})
1001
 
1002
  # Torch:
1003
  if output_torch_format:
1004
  from .export_torch import sympy2torch
1005
 
1006
- module = sympy2torch(eqn, sympy_symbols, selection=selection)
 
 
 
 
 
1007
  torch_format.append(module)
1008
 
1009
  curMSE = output.loc[i, "MSE"]
 
289
  if len(variable_names) == 0:
290
  variable_names = [f"x{i}" for i in range(X.shape[1])]
291
 
292
+ if extra_jax_mappings is not None:
293
+ for key, value in extra_jax_mappings:
294
+ if not isinstance(value, str):
295
+ raise NotImplementedError(
296
+ "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
297
+ )
298
+
299
+ if extra_torch_mappings is not None:
300
+ for key, value in extra_jax_mappings:
301
+ if not callable(value):
302
+ raise NotImplementedError(
303
+ "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
304
+ )
305
+
306
  use_custom_variable_names = len(variable_names) != 0
307
 
308
  _check_assertions(
 
1010
  if output_jax_format:
1011
  from .export_jax import sympy2jax
1012
 
1013
+ func, params = sympy2jax(
1014
+ eqn,
1015
+ sympy_symbols,
1016
+ selection=selection,
1017
+ extra_jax_mappings=extra_jax_mappings,
1018
+ )
1019
  jax_format.append({"callable": func, "parameters": params})
1020
 
1021
  # Torch:
1022
  if output_torch_format:
1023
  from .export_torch import sympy2torch
1024
 
1025
+ module = sympy2torch(
1026
+ eqn,
1027
+ sympy_symbols,
1028
+ selection=selection,
1029
+ extra_torch_mappings=extra_torch_mappings,
1030
+ )
1031
  torch_format.append(module)
1032
 
1033
  curMSE = output.loc[i, "MSE"]