Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
6248183
1
Parent(s):
ea7ced4
refactor: clean up use of `__init__` in keyword suggestion
Browse files- pysr/sr.py +11 -10
- pysr/test/test.py +9 -3
pysr/sr.py
CHANGED
@@ -909,7 +909,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
909 |
FutureWarning,
|
910 |
)
|
911 |
else:
|
912 |
-
suggested_keywords =
|
913 |
err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
|
914 |
if len(suggested_keywords) > 0:
|
915 |
err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
|
@@ -1995,15 +1995,6 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1995 |
|
1996 |
return self
|
1997 |
|
1998 |
-
def _suggest_keywords(self, k: str) -> List[str]:
|
1999 |
-
valid_keywords = [
|
2000 |
-
param
|
2001 |
-
for param in inspect.signature(self.__init__).parameters
|
2002 |
-
if param not in ["self", "kwargs"]
|
2003 |
-
]
|
2004 |
-
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
|
2005 |
-
return suggestions
|
2006 |
-
|
2007 |
def refresh(self, checkpoint_file=None) -> None:
|
2008 |
"""
|
2009 |
Update self.equations_ with any new options passed.
|
@@ -2455,6 +2446,16 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2455 |
return with_preamble(table_string)
|
2456 |
|
2457 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2458 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
|
2459 |
"""Select an expression and return its index."""
|
2460 |
if model_selection == "accuracy":
|
|
|
909 |
FutureWarning,
|
910 |
)
|
911 |
else:
|
912 |
+
suggested_keywords = _suggest_keywords(PySRRegressor, k)
|
913 |
err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
|
914 |
if len(suggested_keywords) > 0:
|
915 |
err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
|
|
|
1995 |
|
1996 |
return self
|
1997 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1998 |
def refresh(self, checkpoint_file=None) -> None:
|
1999 |
"""
|
2000 |
Update self.equations_ with any new options passed.
|
|
|
2446 |
return with_preamble(table_string)
|
2447 |
|
2448 |
|
2449 |
+
def _suggest_keywords(cls, k: str) -> List[str]:
|
2450 |
+
valid_keywords = [
|
2451 |
+
param
|
2452 |
+
for param in inspect.signature(cls.__init__).parameters
|
2453 |
+
if param not in ["self", "kwargs"]
|
2454 |
+
]
|
2455 |
+
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
|
2456 |
+
return suggestions
|
2457 |
+
|
2458 |
+
|
2459 |
def idx_model_selection(equations: pd.DataFrame, model_selection: str):
|
2460 |
"""Select an expression and return its index."""
|
2461 |
if model_selection == "accuracy":
|
pysr/test/test.py
CHANGED
@@ -15,7 +15,12 @@ from .. import PySRRegressor, install, jl
|
|
15 |
from ..export_latex import sympy2latex
|
16 |
from ..feature_selection import _handle_feature_selection, run_feature_selection
|
17 |
from ..julia_helpers import init_julia
|
18 |
-
from ..sr import
|
|
|
|
|
|
|
|
|
|
|
19 |
from ..utils import _csv_filename_to_pkl_filename
|
20 |
from .params import (
|
21 |
DEFAULT_NCYCLES,
|
@@ -805,9 +810,10 @@ class TestHelpMessages(unittest.TestCase):
|
|
805 |
print("Failed", opt["kwargs"])
|
806 |
|
807 |
def test_suggest_keywords(self):
|
808 |
-
model = PySRRegressor()
|
809 |
# Easy
|
810 |
-
self.assertEqual(
|
|
|
|
|
811 |
|
812 |
# More complex, and with error
|
813 |
with self.assertRaises(TypeError) as cm:
|
|
|
15 |
from ..export_latex import sympy2latex
|
16 |
from ..feature_selection import _handle_feature_selection, run_feature_selection
|
17 |
from ..julia_helpers import init_julia
|
18 |
+
from ..sr import (
|
19 |
+
_check_assertions,
|
20 |
+
_process_constraints,
|
21 |
+
_suggest_keywords,
|
22 |
+
idx_model_selection,
|
23 |
+
)
|
24 |
from ..utils import _csv_filename_to_pkl_filename
|
25 |
from .params import (
|
26 |
DEFAULT_NCYCLES,
|
|
|
810 |
print("Failed", opt["kwargs"])
|
811 |
|
812 |
def test_suggest_keywords(self):
|
|
|
813 |
# Easy
|
814 |
+
self.assertEqual(
|
815 |
+
_suggest_keywords(PySRRegressor, "loss_function"), ["loss_function"]
|
816 |
+
)
|
817 |
|
818 |
# More complex, and with error
|
819 |
with self.assertRaises(TypeError) as cm:
|