MilesCranmer commited on
Commit
99fff5c
1 Parent(s): e0c68fc

Add export key error telling user to set function mappings

Browse files
Files changed (2) hide show
  1. pysr/export_jax.py +8 -1
  2. pysr/export_torch.py +8 -1
pysr/export_jax.py CHANGED
@@ -63,7 +63,14 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
63
  )
64
  if extra_jax_mappings is None:
65
  extra_jax_mappings = {}
66
- _func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
 
 
 
 
 
 
 
67
  args = [
68
  sympy2jaxtext(
69
  arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
 
63
  )
64
  if extra_jax_mappings is None:
65
  extra_jax_mappings = {}
66
+ try:
67
+ _func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
68
+ except KeyError:
69
+ raise KeyError(
70
+ f"Function {expr.func} was not found in JAX function mappings."
71
+ "Please add it to extra_jax_mappings in the format, e.g., "
72
+ "{sympy.sqrt: 'jnp.sqrt'}."
73
+ )
74
  args = [
75
  sympy2jaxtext(
76
  arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
pysr/export_torch.py CHANGED
@@ -117,7 +117,14 @@ def _initialize_torch():
117
  self._torch_func = lambda value: value
118
  self._args = ((lambda memodict: memodict[expr.name]),)
119
  else:
120
- self._torch_func = _func_lookup[expr.func]
 
 
 
 
 
 
 
121
  args = []
122
  for arg in expr.args:
123
  try:
 
117
  self._torch_func = lambda value: value
118
  self._args = ((lambda memodict: memodict[expr.name]),)
119
  else:
120
+ try:
121
+ self._torch_func = _func_lookup[expr.func]
122
+ except KeyError:
123
+ raise KeyError(
124
+ f"Function {expr.func} was not found in Torch function mappings."
125
+ "Please add it to extra_torch_mappings in the format, e.g., "
126
+ "{sympy.sqrt: torch.sqrt}."
127
+ )
128
  args = []
129
  for arg in expr.args:
130
  try: