Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
e63cf2d
1
Parent(s):
a0c6429
Fix case of no extra mappings for jax/torch
Browse files- pysr/sr.py +4 -0
pysr/sr.py
CHANGED
@@ -546,6 +546,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
546 |
raise NotImplementedError(
|
547 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
548 |
)
|
|
|
|
|
549 |
|
550 |
if extra_torch_mappings is not None:
|
551 |
for value in extra_jax_mappings.values():
|
@@ -553,6 +555,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
|
|
553 |
raise NotImplementedError(
|
554 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
555 |
)
|
|
|
|
|
556 |
|
557 |
if maxsize > 40:
|
558 |
warnings.warn(
|
|
|
546 |
raise NotImplementedError(
|
547 |
"extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
|
548 |
)
|
549 |
+
else:
|
550 |
+
extra_jax_mappings = {}
|
551 |
|
552 |
if extra_torch_mappings is not None:
|
553 |
for value in extra_jax_mappings.values():
|
|
|
555 |
raise NotImplementedError(
|
556 |
"extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
|
557 |
)
|
558 |
+
else:
|
559 |
+
extra_torch_mappings = {}
|
560 |
|
561 |
if maxsize > 40:
|
562 |
warnings.warn(
|