Spaces:
Running
Running
tttc3
commited on
Commit
·
19d80b0
1
Parent(s):
f570132
Fix to pass test_custom_operator
Browse files- pysr/sr.py +2 -2
pysr/sr.py
CHANGED
@@ -1749,7 +1749,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1749 |
extra_jax_mappings = self.extra_jax_mappings
|
1750 |
extra_torch_mappings = self.extra_torch_mappings
|
1751 |
if extra_jax_mappings is not None:
|
1752 |
-
for value in
|
1753 |
if not isinstance(value, str):
|
1754 |
raise ValueError(
|
1755 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
@@ -1757,7 +1757,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1757 |
else:
|
1758 |
extra_jax_mappings = {}
|
1759 |
if extra_torch_mappings is not None:
|
1760 |
-
for value in
|
1761 |
if not callable(value):
|
1762 |
raise ValueError(
|
1763 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
|
|
1749 |
extra_jax_mappings = self.extra_jax_mappings
|
1750 |
extra_torch_mappings = self.extra_torch_mappings
|
1751 |
if extra_jax_mappings is not None:
|
1752 |
+
for value in extra_jax_mappings.values():
|
1753 |
if not isinstance(value, str):
|
1754 |
raise ValueError(
|
1755 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
|
|
1757 |
else:
|
1758 |
extra_jax_mappings = {}
|
1759 |
if extra_torch_mappings is not None:
|
1760 |
+
for value in extra_jax_mappings.values():
|
1761 |
if not callable(value):
|
1762 |
raise ValueError(
|
1763 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|