MilesCranmer commited on
Commit
be36d4a
1 Parent(s): 25e0721

Add determinism-based scikit-learn tests back

Browse files
Files changed (1) hide show
  1. test/test.py +12 -34
test/test.py CHANGED
@@ -348,19 +348,16 @@ class TestMiscellaneous(unittest.TestCase):
348
  def test_scikit_learn_compatibility(self):
349
  """Test PySRRegressor compatibility with scikit-learn."""
350
  model = PySRRegressor(
351
- max_evals=10000, verbosity=0, progress=False
 
 
 
 
 
 
352
  ) # Return early.
353
 
354
- # TODO: Add deterministic option so that we can test these.
355
- # (would require backend changes, and procs=0 for serialism.)
356
  check_generator = check_estimator(model, generate_only=True)
357
- tests_requiring_determinism = [
358
- "check_regressors_int", # PySR is not deterministic, so fails this.
359
- "check_regressor_data_not_an_array",
360
- "check_supervised_y_2d",
361
- "check_regressors_int",
362
- "check_fit_idempotent",
363
- ]
364
  exception_messages = []
365
  for (_, check) in check_generator:
366
  try:
@@ -376,29 +373,10 @@ class TestMiscellaneous(unittest.TestCase):
376
  print("Passed", check.func.__name__)
377
  except Exception as e:
378
  error_message = str(e)
379
- failed_tolerance_check = "Not equal to tolerance" in error_message
380
-
381
- if (
382
- failed_tolerance_check
383
- and check.func.__name__ in tests_requiring_determinism
384
- ):
385
- # Skip test as PySR is not deterministic.
386
- print(
387
- "Failed",
388
- check.func.__name__,
389
- "which is an allowed failure, as the test requires determinism.",
390
- )
391
- else:
392
- exception_messages.append(
393
- f"{check.func.__name__}: {error_message}\n"
394
- )
395
- print("Failed", check.func.__name__, "with:")
396
- # Add a leading tab to error message, which
397
- # might be multi-line:
398
- print(
399
- "\n".join(
400
- [(" " * 4) + row for row in error_message.split("\n")]
401
- )
402
- )
403
  # If any checks failed don't let the test pass.
404
  self.assertEqual([], exception_messages)
 
348
  def test_scikit_learn_compatibility(self):
349
  """Test PySRRegressor compatibility with scikit-learn."""
350
  model = PySRRegressor(
351
+ max_evals=10000,
352
+ verbosity=0,
353
+ progress=False,
354
+ random_state=0,
355
+ deterministic=True,
356
+ procs=0,
357
+ multithreading=False,
358
  ) # Return early.
359
 
 
 
360
  check_generator = check_estimator(model, generate_only=True)
 
 
 
 
 
 
 
361
  exception_messages = []
362
  for (_, check) in check_generator:
363
  try:
 
373
  print("Passed", check.func.__name__)
374
  except Exception as e:
375
  error_message = str(e)
376
+ exception_messages.append(f"{check.func.__name__}: {error_message}\n")
377
+ print("Failed", check.func.__name__, "with:")
378
+ # Add a leading tab to error message, which
379
+ # might be multi-line:
380
+ print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  # If any checks failed don't let the test pass.
382
  self.assertEqual([], exception_messages)