MilesCranmer commited on
Commit
6248183
1 Parent(s): ea7ced4

refactor: clean up use of `__init__` in keyword suggestion

Browse files
Files changed (2) hide show
  1. pysr/sr.py +11 -10
  2. pysr/test/test.py +9 -3
pysr/sr.py CHANGED
@@ -909,7 +909,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
909
  FutureWarning,
910
  )
911
  else:
912
- suggested_keywords = self._suggest_keywords(k)
913
  err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
914
  if len(suggested_keywords) > 0:
915
  err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
@@ -1995,15 +1995,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1995
 
1996
  return self
1997
 
1998
- def _suggest_keywords(self, k: str) -> List[str]:
1999
- valid_keywords = [
2000
- param
2001
- for param in inspect.signature(self.__init__).parameters
2002
- if param not in ["self", "kwargs"]
2003
- ]
2004
- suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
2005
- return suggestions
2006
-
2007
  def refresh(self, checkpoint_file=None) -> None:
2008
  """
2009
  Update self.equations_ with any new options passed.
@@ -2455,6 +2446,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
2455
  return with_preamble(table_string)
2456
 
2457
 
 
 
 
 
 
 
 
 
 
 
2458
  def idx_model_selection(equations: pd.DataFrame, model_selection: str):
2459
  """Select an expression and return its index."""
2460
  if model_selection == "accuracy":
 
909
  FutureWarning,
910
  )
911
  else:
912
+ suggested_keywords = _suggest_keywords(PySRRegressor, k)
913
  err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
914
  if len(suggested_keywords) > 0:
915
  err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
 
1995
 
1996
  return self
1997
 
 
 
 
 
 
 
 
 
 
1998
  def refresh(self, checkpoint_file=None) -> None:
1999
  """
2000
  Update self.equations_ with any new options passed.
 
2446
  return with_preamble(table_string)
2447
 
2448
 
2449
+ def _suggest_keywords(cls, k: str) -> List[str]:
2450
+ valid_keywords = [
2451
+ param
2452
+ for param in inspect.signature(cls.__init__).parameters
2453
+ if param not in ["self", "kwargs"]
2454
+ ]
2455
+ suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
2456
+ return suggestions
2457
+
2458
+
2459
  def idx_model_selection(equations: pd.DataFrame, model_selection: str):
2460
  """Select an expression and return its index."""
2461
  if model_selection == "accuracy":
pysr/test/test.py CHANGED
@@ -15,7 +15,12 @@ from .. import PySRRegressor, install, jl
15
  from ..export_latex import sympy2latex
16
  from ..feature_selection import _handle_feature_selection, run_feature_selection
17
  from ..julia_helpers import init_julia
18
- from ..sr import _check_assertions, _process_constraints, idx_model_selection
 
 
 
 
 
19
  from ..utils import _csv_filename_to_pkl_filename
20
  from .params import (
21
  DEFAULT_NCYCLES,
@@ -805,9 +810,10 @@ class TestHelpMessages(unittest.TestCase):
805
  print("Failed", opt["kwargs"])
806
 
807
  def test_suggest_keywords(self):
808
- model = PySRRegressor()
809
  # Easy
810
- self.assertEqual(model._suggest_keywords("loss_function"), ["loss_function"])
 
 
811
 
812
  # More complex, and with error
813
  with self.assertRaises(TypeError) as cm:
 
15
  from ..export_latex import sympy2latex
16
  from ..feature_selection import _handle_feature_selection, run_feature_selection
17
  from ..julia_helpers import init_julia
18
+ from ..sr import (
19
+ _check_assertions,
20
+ _process_constraints,
21
+ _suggest_keywords,
22
+ idx_model_selection,
23
+ )
24
  from ..utils import _csv_filename_to_pkl_filename
25
  from .params import (
26
  DEFAULT_NCYCLES,
 
810
  print("Failed", opt["kwargs"])
811
 
812
  def test_suggest_keywords(self):
 
813
  # Easy
814
+ self.assertEqual(
815
+ _suggest_keywords(PySRRegressor, "loss_function"), ["loss_function"]
816
+ )
817
 
818
  # More complex, and with error
819
  with self.assertRaises(TypeError) as cm: