MilesCranmer commited on
Commit
f340c5b
·
1 Parent(s): 162cbb5

Assert variable names are alphanumeric

Browse files
Files changed (2) hide show
  1. pysr/sr.py +7 -0
  2. test/test.py +24 -1
pysr/sr.py CHANGED
@@ -181,6 +181,13 @@ def _check_assertions(
181
  raise ValueError(
182
  f"Variable name {var_name} is already a function name."
183
  )
 
 
 
 
 
 
 
184
 
185
 
186
  def best(*args, **kwargs): # pragma: no cover
 
181
  raise ValueError(
182
  f"Variable name {var_name} is already a function name."
183
  )
184
+ # Check if alphanumeric only:
185
+ if not re.match(r"^[a-zA-Z0-9_]+$", var_name):
186
+ raise ValueError(
187
+ f"Invalid variable name {var_name}. "
188
+ "Only alphanumeric characters, numbers, "
189
+ "and underscores are allowed."
190
+ )
191
 
192
 
193
  def best(*args, **kwargs): # pragma: no cover
test/test.py CHANGED
@@ -572,8 +572,31 @@ class TestMiscellaneous(unittest.TestCase):
572
  model = PySRRegressor()
573
  X = np.random.randn(100, 2)
574
  y = np.random.randn(100)
575
- with self.assertRaises(ValueError):
576
  model.fit(X, y, variable_names=["x1", "N"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
577
 
578
  def test_pickle_with_temp_equation_file(self):
579
  """If we have a temporary equation file, unpickle the estimator."""
 
572
  model = PySRRegressor()
573
  X = np.random.randn(100, 2)
574
  y = np.random.randn(100)
575
+ with self.assertRaises(ValueError) as cm:
576
  model.fit(X, y, variable_names=["x1", "N"])
577
+ self.assertIn(
578
+ "Variable name",
579
+ str(cm.exception)
580
+ )
581
+
582
+ def test_bad_variable_names_fail(self):
583
+ model = PySRRegressor()
584
+ X = np.random.randn(100, 1)
585
+ y = np.random.randn(100)
586
+
587
+ with self.assertRaises(ValueError) as cm:
588
+ model.fit(X, y, variable_names=["Tr(Tij)"])
589
+ self.assertIn(
590
+ "Invalid variable name",
591
+ str(cm.exception)
592
+ )
593
+
594
+ with self.assertRaises(ValueError) as cm:
595
+ model.fit(X, y, variable_names=["f{c}"])
596
+ self.assertIn(
597
+ "Invalid variable name",
598
+ str(cm.exception)
599
+ )
600
 
601
  def test_pickle_with_temp_equation_file(self):
602
  """If we have a temporary equation file, unpickle the estimator."""