MilesCranmer commited on
Commit
e29a6da
·
1 Parent(s): d170500

Ensure that variable names are not sympy functions

Browse files
Files changed (2) hide show
  1. pysr/sr.py +6 -0
  2. test/test.py +7 -0
pysr/sr.py CHANGED
@@ -169,6 +169,12 @@ def _check_assertions(
169
  assert X.shape[0] == weights.shape[0]
170
  if use_custom_variable_names:
171
  assert len(variable_names) == X.shape[1]
 
 
 
 
 
 
172
 
173
 
174
  def best(*args, **kwargs): # pragma: no cover
 
169
  assert X.shape[0] == weights.shape[0]
170
  if use_custom_variable_names:
171
  assert len(variable_names) == X.shape[1]
172
+ # Check none of the variable names are function names:
173
+ for var_name in variable_names:
174
+ if var_name in sympy_mappings or var_name in sympy.__dict__.keys():
175
+ raise ValueError(
176
+ f"Variable name {var_name} is already a function name."
177
+ )
178
 
179
 
180
  def best(*args, **kwargs): # pragma: no cover
test/test.py CHANGED
@@ -546,6 +546,13 @@ class TestMiscellaneous(unittest.TestCase):
546
  with self.assertRaises(ValueError):
547
  model.fit(X, y)
548
 
 
 
 
 
 
 
 
549
  def test_pickle_with_temp_equation_file(self):
550
  """If we have a temporary equation file, unpickle the estimator."""
551
  model = PySRRegressor(
 
546
  with self.assertRaises(ValueError):
547
  model.fit(X, y)
548
 
549
+ def test_sympy_function_fails_as_variable(self):
550
+ model = PySRRegressor()
551
+ X = np.random.randn(100, 2)
552
+ y = np.random.randn(100)
553
+ with self.assertRaises(ValueError):
554
+ model.fit(X, y, variable_names=["x1", "N"])
555
+
556
  def test_pickle_with_temp_equation_file(self):
557
  """If we have a temporary equation file, unpickle the estimator."""
558
  model = PySRRegressor(