MilesCranmer commited on
Commit
ea7ced4
1 Parent(s): 2f528dc

feat: automatically suggest related keywords

Browse files
Files changed (2) hide show
  1. pysr/sr.py +16 -3
  2. pysr/test/test.py +20 -0
pysr/sr.py CHANGED
@@ -1,6 +1,8 @@
1
  """Define the PySRRegressor scikit-learn interface."""
2
 
3
  import copy
 
 
4
  import os
5
  import pickle as pkl
6
  import re
@@ -907,9 +909,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
907
  FutureWarning,
908
  )
909
  else:
910
- raise TypeError(
911
- f"{k} is not a valid keyword argument for PySRRegressor."
912
- )
 
 
913
 
914
  @classmethod
915
  def from_file(
@@ -1991,6 +1995,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1991
 
1992
  return self
1993
 
 
 
 
 
 
 
 
 
 
1994
  def refresh(self, checkpoint_file=None) -> None:
1995
  """
1996
  Update self.equations_ with any new options passed.
 
1
  """Define the PySRRegressor scikit-learn interface."""
2
 
3
  import copy
4
+ import difflib
5
+ import inspect
6
  import os
7
  import pickle as pkl
8
  import re
 
909
  FutureWarning,
910
  )
911
  else:
912
+ suggested_keywords = self._suggest_keywords(k)
913
+ err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
914
+ if len(suggested_keywords) > 0:
915
+ err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
916
+ raise TypeError(err_msg)
917
 
918
  @classmethod
919
  def from_file(
 
1995
 
1996
  return self
1997
 
1998
+ def _suggest_keywords(self, k: str) -> List[str]:
1999
+ valid_keywords = [
2000
+ param
2001
+ for param in inspect.signature(self.__init__).parameters
2002
+ if param not in ["self", "kwargs"]
2003
+ ]
2004
+ suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
2005
+ return suggestions
2006
+
2007
  def refresh(self, checkpoint_file=None) -> None:
2008
  """
2009
  Update self.equations_ with any new options passed.
pysr/test/test.py CHANGED
@@ -804,6 +804,26 @@ class TestHelpMessages(unittest.TestCase):
804
  model.get_best()
805
  print("Failed", opt["kwargs"])
806
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
807
 
808
  TRUE_PREAMBLE = "\n".join(
809
  [
 
804
  model.get_best()
805
  print("Failed", opt["kwargs"])
806
 
807
+ def test_suggest_keywords(self):
808
+ model = PySRRegressor()
809
+ # Easy
810
+ self.assertEqual(model._suggest_keywords("loss_function"), ["loss_function"])
811
+
812
+ # More complex, and with error
813
+ with self.assertRaises(TypeError) as cm:
814
+ model = PySRRegressor(ncyclesperiterationn=5)
815
+
816
+ self.assertIn("ncyclesperiterationn is not a valid keyword", str(cm.exception))
817
+ self.assertIn("Did you mean", str(cm.exception))
818
+ self.assertIn("ncycles_per_iteration or", str(cm.exception))
819
+ self.assertIn("niteration", str(cm.exception))
820
+
821
+ # Farther matches (this might need to be changed)
822
+ with self.assertRaises(TypeError) as cm:
823
+ model = PySRRegressor(operators=["+", "-"])
824
+
825
+ self.assertIn("unary_operators or binary_operators", str(cm.exception))
826
+
827
 
828
  TRUE_PREAMBLE = "\n".join(
829
  [