MilesCranmer commited on
Commit
045bdb1
·
1 Parent(s): af8ab17

Add tests for determinism warnings

Browse files
Files changed (1) hide show
  1. test/test.py +19 -0
test/test.py CHANGED
@@ -358,6 +358,25 @@ class TestMiscellaneous(unittest.TestCase):
358
  model.fit(X, y)
359
  self.assertIn("with 10 features or more", str(context.exception))
360
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
361
  def test_scikit_learn_compatibility(self):
362
  """Test PySRRegressor compatibility with scikit-learn."""
363
  model = PySRRegressor(
 
358
  model.fit(X, y)
359
  self.assertIn("with 10 features or more", str(context.exception))
360
 
361
+ def test_deterministic_warnings(self):
362
+ """Ensure that warnings are given for determinism"""
363
+ model = PySRRegressor(random_state=0)
364
+ X = np.random.randn(100, 2)
365
+ y = np.random.randn(100)
366
+ with warnings.catch_warnings():
367
+ warnings.simplefilter("error")
368
+ with self.assertRaises(Exception) as context:
369
+ model.fit(X, y)
370
+ self.assertIn("`deterministic`", str(context.exception))
371
+
372
+ def test_deterministic_errors(self):
373
+ """Setting deterministic without random_state should error"""
374
+ model = PySRRegressor(deterministic=True)
375
+ X = np.random.randn(100, 2)
376
+ y = np.random.randn(100)
377
+ with self.assertRaises(ValueError):
378
+ model.fit(X, y)
379
+
380
  def test_scikit_learn_compatibility(self):
381
  """Test PySRRegressor compatibility with scikit-learn."""
382
  model = PySRRegressor(