MilesCranmer commited on
Commit
e63cf2d
1 Parent(s): a0c6429

Fix case of no extra mappings for jax/torch

Browse files
Files changed (1) hide show
  1. pysr/sr.py +4 -0
pysr/sr.py CHANGED
@@ -546,6 +546,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
546
  raise NotImplementedError(
547
  "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
548
  )
 
 
549
 
550
  if extra_torch_mappings is not None:
551
  for value in extra_jax_mappings.values():
@@ -553,6 +555,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
553
  raise NotImplementedError(
554
  "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
555
  )
 
 
556
 
557
  if maxsize > 40:
558
  warnings.warn(
 
546
  raise NotImplementedError(
547
  "extra_jax_mappings must have keys that are strings! e.g., {sympy.sqrt: 'jnp.sqrt'}."
548
  )
549
+ else:
550
+ extra_jax_mappings = {}
551
 
552
  if extra_torch_mappings is not None:
553
  for value in extra_jax_mappings.values():
 
555
  raise NotImplementedError(
556
  "extra_torch_mappings must be callable functions! e.g., {sympy.sqrt: torch.sqrt}."
557
  )
558
+ else:
559
+ extra_torch_mappings = {}
560
 
561
  if maxsize > 40:
562
  warnings.warn(