tttc3 commited on
Commit
19d80b0
·
1 Parent(s): f570132

Fix to pass test_custom_operator

Browse files
Files changed (1) hide show
  1. pysr/sr.py +2 -2
pysr/sr.py CHANGED
@@ -1749,7 +1749,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1749
  extra_jax_mappings = self.extra_jax_mappings
1750
  extra_torch_mappings = self.extra_torch_mappings
1751
  if extra_jax_mappings is not None:
1752
- for value in self.extra_jax_mappings.values():
1753
  if not isinstance(value, str):
1754
  raise ValueError(
1755
  "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
@@ -1757,7 +1757,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1757
  else:
1758
  extra_jax_mappings = {}
1759
  if extra_torch_mappings is not None:
1760
- for value in self.extra_jax_mappings.values():
1761
  if not callable(value):
1762
  raise ValueError(
1763
  "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
 
1749
  extra_jax_mappings = self.extra_jax_mappings
1750
  extra_torch_mappings = self.extra_torch_mappings
1751
  if extra_jax_mappings is not None:
1752
+ for value in extra_jax_mappings.values():
1753
  if not isinstance(value, str):
1754
  raise ValueError(
1755
  "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
 
1757
  else:
1758
  extra_jax_mappings = {}
1759
  if extra_torch_mappings is not None:
1760
+ for value in extra_jax_mappings.values():
1761
  if not callable(value):
1762
  raise ValueError(
1763
  "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."