MilesCranmer commited on
Commit
96d6ea9
2 Parent(s): f653388 6088859

Merge pull request #620 from MilesCranmer/autocorrect-kwarg

Browse files
Files changed (2) hide show
  1. pysr/sr.py +21 -5
  2. pysr/test/test.py +123 -90
pysr/sr.py CHANGED
@@ -1,6 +1,8 @@
1
  """Define the PySRRegressor scikit-learn interface."""
2
 
3
  import copy
 
 
4
  import os
5
  import pickle as pkl
6
  import re
@@ -900,15 +902,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
900
  updated_kwarg_name = DEPRECATED_KWARGS[k]
901
  setattr(self, updated_kwarg_name, v)
902
  warnings.warn(
903
- f"{k} has been renamed to {updated_kwarg_name} in PySRRegressor. "
904
  "Please use that instead.",
905
  FutureWarning,
906
  )
907
  # Handle kwargs that have been moved to the fit method
908
  elif k in ["weights", "variable_names", "Xresampled"]:
909
  warnings.warn(
910
- f"{k} is a data dependant parameter so should be passed when fit is called. "
911
- f"Ignoring parameter; please pass {k} during the call to fit instead.",
912
  FutureWarning,
913
  )
914
  elif k == "julia_project":
@@ -925,9 +927,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
925
  FutureWarning,
926
  )
927
  else:
928
- raise TypeError(
929
- f"{k} is not a valid keyword argument for PySRRegressor."
 
930
  )
 
 
 
931
 
932
  @classmethod
933
  def from_file(
@@ -2459,6 +2465,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2459
  return with_preamble(table_string)
2460
 
2461
 
 
 
 
 
 
 
 
 
 
 
2462
  def idx_model_selection(equations: pd.DataFrame, model_selection: str):
2463
  """Select an expression and return its index."""
2464
  if model_selection == "accuracy":
 
1
  """Define the PySRRegressor scikit-learn interface."""
2
 
3
  import copy
4
+ import difflib
5
+ import inspect
6
  import os
7
  import pickle as pkl
8
  import re
 
902
  updated_kwarg_name = DEPRECATED_KWARGS[k]
903
  setattr(self, updated_kwarg_name, v)
904
  warnings.warn(
905
+ f"`{k}` has been renamed to `{updated_kwarg_name}` in PySRRegressor. "
906
  "Please use that instead.",
907
  FutureWarning,
908
  )
909
  # Handle kwargs that have been moved to the fit method
910
  elif k in ["weights", "variable_names", "Xresampled"]:
911
  warnings.warn(
912
+ f"`{k}` is a data-dependent parameter and should be passed when fit is called. "
913
+ f"Ignoring parameter; please pass `{k}` during the call to fit instead.",
914
  FutureWarning,
915
  )
916
  elif k == "julia_project":
 
927
  FutureWarning,
928
  )
929
  else:
930
+ suggested_keywords = _suggest_keywords(PySRRegressor, k)
931
+ err_msg = (
932
+ f"`{k}` is not a valid keyword argument for PySRRegressor."
933
  )
934
+ if len(suggested_keywords) > 0:
935
+ err_msg += f" Did you mean {', '.join(map(lambda s: f'`{s}`', suggested_keywords))}?"
936
+ raise TypeError(err_msg)
937
 
938
  @classmethod
939
  def from_file(
 
2465
  return with_preamble(table_string)
2466
 
2467
 
2468
+ def _suggest_keywords(cls, k: str) -> List[str]:
2469
+ valid_keywords = [
2470
+ param
2471
+ for param in inspect.signature(cls.__init__).parameters
2472
+ if param not in ["self", "kwargs"]
2473
+ ]
2474
+ suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
2475
+ return suggestions
2476
+
2477
+
2478
  def idx_model_selection(equations: pd.DataFrame, model_selection: str):
2479
  """Select an expression and return its index."""
2480
  if model_selection == "accuracy":
pysr/test/test.py CHANGED
@@ -15,7 +15,12 @@ from .. import PySRRegressor, install, jl
15
  from ..export_latex import sympy2latex
16
  from ..feature_selection import _handle_feature_selection, run_feature_selection
17
  from ..julia_helpers import init_julia
18
- from ..sr import _check_assertions, _process_constraints, idx_model_selection
 
 
 
 
 
19
  from ..utils import _csv_filename_to_pkl_filename
20
  from .params import (
21
  DEFAULT_NCYCLES,
@@ -573,6 +578,105 @@ class TestMiscellaneous(unittest.TestCase):
573
  test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
574
  self.assertEqual(test_pkl_file, str(expected_pkl_file))
575
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
576
  def test_deprecation(self):
577
  """Ensure that deprecation works as expected.
578
 
@@ -715,100 +819,28 @@ class TestMiscellaneous(unittest.TestCase):
715
  model.get_best()
716
  print("Failed", opt["kwargs"])
717
 
718
- def test_pickle_with_temp_equation_file(self):
719
- """If we have a temporary equation file, unpickle the estimator."""
720
- model = PySRRegressor(
721
- populations=int(1 + DEFAULT_POPULATIONS / 5),
722
- temp_equation_file=True,
723
- procs=0,
724
- multithreading=False,
725
  )
726
- nout = 3
727
- X = np.random.randn(100, 2)
728
- y = np.random.randn(100, nout)
729
- model.fit(X, y)
730
- contents = model.equation_file_contents_.copy()
731
-
732
- y_predictions = model.predict(X)
733
-
734
- equation_file_base = model.equation_file_
735
- for i in range(1, nout + 1):
736
- assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
737
-
738
- with tempfile.NamedTemporaryFile() as pickle_file:
739
- pkl.dump(model, pickle_file)
740
- pickle_file.seek(0)
741
- model2 = pkl.load(pickle_file)
742
-
743
- contents2 = model2.equation_file_contents_
744
- cols_to_check = ["equation", "loss", "complexity"]
745
- for frame1, frame2 in zip(contents, contents2):
746
- pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
747
 
748
- y_predictions2 = model2.predict(X)
749
- np.testing.assert_array_equal(y_predictions, y_predictions2)
 
750
 
751
- def test_scikit_learn_compatibility(self):
752
- """Test PySRRegressor compatibility with scikit-learn."""
753
- model = PySRRegressor(
754
- niterations=int(1 + DEFAULT_NITERATIONS / 10),
755
- populations=int(1 + DEFAULT_POPULATIONS / 3),
756
- ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
757
- verbosity=0,
758
- progress=False,
759
- random_state=0,
760
- deterministic=True, # Deterministic as tests require this.
761
- procs=0,
762
- multithreading=False,
763
- warm_start=False,
764
- temp_equation_file=True,
765
- ) # Return early.
766
-
767
- check_generator = check_estimator(model, generate_only=True)
768
- exception_messages = []
769
- for _, check in check_generator:
770
- if check.func.__name__ == "check_complex_data":
771
- # We can use complex data, so avoid this check.
772
- continue
773
- try:
774
- with warnings.catch_warnings():
775
- warnings.simplefilter("ignore")
776
- check(model)
777
- print("Passed", check.func.__name__)
778
- except Exception:
779
- error_message = str(traceback.format_exc())
780
- exception_messages.append(
781
- f"{check.func.__name__}:\n" + error_message + "\n"
782
- )
783
- print("Failed", check.func.__name__, "with:")
784
- # Add a leading tab to error message, which
785
- # might be multi-line:
786
- print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
787
- # If any checks failed don't let the test pass.
788
- self.assertEqual(len(exception_messages), 0)
789
-
790
- def test_param_groupings(self):
791
- """Test that param_groupings are complete"""
792
- param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
793
- if not param_groupings_file.exists():
794
- return
795
-
796
- # Read the file, discarding lines ending in ":",
797
- # and removing leading "\s*-\s*":
798
- params = []
799
- with open(param_groupings_file, "r") as f:
800
- for line in f.readlines():
801
- if line.strip().endswith(":"):
802
- continue
803
- if line.strip().startswith("-"):
804
- params.append(line.strip()[1:].strip())
805
 
806
- regressor_params = [
807
- p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
808
- ]
809
 
810
- # Check the sets are equal:
811
- self.assertSetEqual(set(params), set(regressor_params))
812
 
813
 
814
  TRUE_PREAMBLE = "\n".join(
@@ -1158,6 +1190,7 @@ def runtests(just_tests=False):
1158
  TestBest,
1159
  TestFeatureSelection,
1160
  TestMiscellaneous,
 
1161
  TestLaTeXTable,
1162
  TestDimensionalConstraints,
1163
  ]
 
15
  from ..export_latex import sympy2latex
16
  from ..feature_selection import _handle_feature_selection, run_feature_selection
17
  from ..julia_helpers import init_julia
18
+ from ..sr import (
19
+ _check_assertions,
20
+ _process_constraints,
21
+ _suggest_keywords,
22
+ idx_model_selection,
23
+ )
24
  from ..utils import _csv_filename_to_pkl_filename
25
  from .params import (
26
  DEFAULT_NCYCLES,
 
578
  test_pkl_file = _csv_filename_to_pkl_filename(str(equation_file))
579
  self.assertEqual(test_pkl_file, str(expected_pkl_file))
580
 
581
+ def test_pickle_with_temp_equation_file(self):
582
+ """If we have a temporary equation file, unpickle the estimator."""
583
+ model = PySRRegressor(
584
+ populations=int(1 + DEFAULT_POPULATIONS / 5),
585
+ temp_equation_file=True,
586
+ procs=0,
587
+ multithreading=False,
588
+ )
589
+ nout = 3
590
+ X = np.random.randn(100, 2)
591
+ y = np.random.randn(100, nout)
592
+ model.fit(X, y)
593
+ contents = model.equation_file_contents_.copy()
594
+
595
+ y_predictions = model.predict(X)
596
+
597
+ equation_file_base = model.equation_file_
598
+ for i in range(1, nout + 1):
599
+ assert not os.path.exists(str(equation_file_base) + f".out{i}.bkup")
600
+
601
+ with tempfile.NamedTemporaryFile() as pickle_file:
602
+ pkl.dump(model, pickle_file)
603
+ pickle_file.seek(0)
604
+ model2 = pkl.load(pickle_file)
605
+
606
+ contents2 = model2.equation_file_contents_
607
+ cols_to_check = ["equation", "loss", "complexity"]
608
+ for frame1, frame2 in zip(contents, contents2):
609
+ pd.testing.assert_frame_equal(frame1[cols_to_check], frame2[cols_to_check])
610
+
611
+ y_predictions2 = model2.predict(X)
612
+ np.testing.assert_array_equal(y_predictions, y_predictions2)
613
+
614
+ def test_scikit_learn_compatibility(self):
615
+ """Test PySRRegressor compatibility with scikit-learn."""
616
+ model = PySRRegressor(
617
+ niterations=int(1 + DEFAULT_NITERATIONS / 10),
618
+ populations=int(1 + DEFAULT_POPULATIONS / 3),
619
+ ncycles_per_iteration=int(2 + DEFAULT_NCYCLES / 10),
620
+ verbosity=0,
621
+ progress=False,
622
+ random_state=0,
623
+ deterministic=True, # Deterministic as tests require this.
624
+ procs=0,
625
+ multithreading=False,
626
+ warm_start=False,
627
+ temp_equation_file=True,
628
+ ) # Return early.
629
+
630
+ check_generator = check_estimator(model, generate_only=True)
631
+ exception_messages = []
632
+ for _, check in check_generator:
633
+ if check.func.__name__ == "check_complex_data":
634
+ # We can use complex data, so avoid this check.
635
+ continue
636
+ try:
637
+ with warnings.catch_warnings():
638
+ warnings.simplefilter("ignore")
639
+ check(model)
640
+ print("Passed", check.func.__name__)
641
+ except Exception:
642
+ error_message = str(traceback.format_exc())
643
+ exception_messages.append(
644
+ f"{check.func.__name__}:\n" + error_message + "\n"
645
+ )
646
+ print("Failed", check.func.__name__, "with:")
647
+ # Add a leading tab to error message, which
648
+ # might be multi-line:
649
+ print("\n".join([(" " * 4) + row for row in error_message.split("\n")]))
650
+ # If any checks failed don't let the test pass.
651
+ self.assertEqual(len(exception_messages), 0)
652
+
653
+ def test_param_groupings(self):
654
+ """Test that param_groupings are complete"""
655
+ param_groupings_file = Path(__file__).parent.parent / "param_groupings.yml"
656
+ if not param_groupings_file.exists():
657
+ return
658
+
659
+ # Read the file, discarding lines ending in ":",
660
+ # and removing leading "\s*-\s*":
661
+ params = []
662
+ with open(param_groupings_file, "r") as f:
663
+ for line in f.readlines():
664
+ if line.strip().endswith(":"):
665
+ continue
666
+ if line.strip().startswith("-"):
667
+ params.append(line.strip()[1:].strip())
668
+
669
+ regressor_params = [
670
+ p for p in DEFAULT_PARAMS.keys() if p not in ["self", "kwargs"]
671
+ ]
672
+
673
+ # Check the sets are equal:
674
+ self.assertSetEqual(set(params), set(regressor_params))
675
+
676
+
677
+ class TestHelpMessages(unittest.TestCase):
678
+ """Test user help messages."""
679
+
680
  def test_deprecation(self):
681
  """Ensure that deprecation works as expected.
682
 
 
819
  model.get_best()
820
  print("Failed", opt["kwargs"])
821
 
822
+ def test_suggest_keywords(self):
823
+ # Easy
824
+ self.assertEqual(
825
+ _suggest_keywords(PySRRegressor, "loss_function"), ["loss_function"]
 
 
 
826
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
827
 
828
+ # More complex, and with error
829
+ with self.assertRaises(TypeError) as cm:
830
+ model = PySRRegressor(ncyclesperiterationn=5)
831
 
832
+ self.assertIn(
833
+ "`ncyclesperiterationn` is not a valid keyword", str(cm.exception)
834
+ )
835
+ self.assertIn("Did you mean", str(cm.exception))
836
+ self.assertIn("`ncycles_per_iteration`, ", str(cm.exception))
837
+ self.assertIn("`niterations`", str(cm.exception))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
838
 
839
+ # Farther matches (this might need to be changed)
840
+ with self.assertRaises(TypeError) as cm:
841
+ model = PySRRegressor(operators=["+", "-"])
842
 
843
+ self.assertIn("`unary_operators`, `binary_operators`", str(cm.exception))
 
844
 
845
 
846
  TRUE_PREAMBLE = "\n".join(
 
1190
  TestBest,
1191
  TestFeatureSelection,
1192
  TestMiscellaneous,
1193
+ TestHelpMessages,
1194
  TestLaTeXTable,
1195
  TestDimensionalConstraints,
1196
  ]