Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
99fff5c
1
Parent(s):
e0c68fc
Add export key error telling user to set function mappings
Browse files- pysr/export_jax.py +8 -1
- pysr/export_torch.py +8 -1
pysr/export_jax.py
CHANGED
@@ -63,7 +63,14 @@ def sympy2jaxtext(expr, parameters, symbols_in, extra_jax_mappings=None):
|
|
63 |
)
|
64 |
if extra_jax_mappings is None:
|
65 |
extra_jax_mappings = {}
|
66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
67 |
args = [
|
68 |
sympy2jaxtext(
|
69 |
arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
|
|
|
63 |
)
|
64 |
if extra_jax_mappings is None:
|
65 |
extra_jax_mappings = {}
|
66 |
+
try:
|
67 |
+
_func = {**_jnp_func_lookup, **extra_jax_mappings}[expr.func]
|
68 |
+
except KeyError:
|
69 |
+
raise KeyError(
|
70 |
+
f"Function {expr.func} was not found in JAX function mappings."
|
71 |
+
"Please add it to extra_jax_mappings in the format, e.g., "
|
72 |
+
"{sympy.sqrt: 'jnp.sqrt'}."
|
73 |
+
)
|
74 |
args = [
|
75 |
sympy2jaxtext(
|
76 |
arg, parameters, symbols_in, extra_jax_mappings=extra_jax_mappings
|
pysr/export_torch.py
CHANGED
@@ -117,7 +117,14 @@ def _initialize_torch():
|
|
117 |
self._torch_func = lambda value: value
|
118 |
self._args = ((lambda memodict: memodict[expr.name]),)
|
119 |
else:
|
120 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
args = []
|
122 |
for arg in expr.args:
|
123 |
try:
|
|
|
117 |
self._torch_func = lambda value: value
|
118 |
self._args = ((lambda memodict: memodict[expr.name]),)
|
119 |
else:
|
120 |
+
try:
|
121 |
+
self._torch_func = _func_lookup[expr.func]
|
122 |
+
except KeyError:
|
123 |
+
raise KeyError(
|
124 |
+
f"Function {expr.func} was not found in Torch function mappings."
|
125 |
+
"Please add it to extra_torch_mappings in the format, e.g., "
|
126 |
+
"{sympy.sqrt: torch.sqrt}."
|
127 |
+
)
|
128 |
args = []
|
129 |
for arg in expr.args:
|
130 |
try:
|