Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
ea7ced4
1
Parent(s):
2f528dc
feat: automatically suggest related keywords
Browse files- pysr/sr.py +16 -3
- pysr/test/test.py +20 -0
pysr/sr.py
CHANGED
@@ -1,6 +1,8 @@
|
|
1 |
"""Define the PySRRegressor scikit-learn interface."""
|
2 |
|
3 |
import copy
|
|
|
|
|
4 |
import os
|
5 |
import pickle as pkl
|
6 |
import re
|
@@ -907,9 +909,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
907 |
FutureWarning,
|
908 |
)
|
909 |
else:
|
910 |
-
|
911 |
-
|
912 |
-
)
|
|
|
|
|
913 |
|
914 |
@classmethod
|
915 |
def from_file(
|
@@ -1991,6 +1995,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1991 |
|
1992 |
return self
|
1993 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1994 |
def refresh(self, checkpoint_file=None) -> None:
|
1995 |
"""
|
1996 |
Update self.equations_ with any new options passed.
|
|
|
1 |
"""Define the PySRRegressor scikit-learn interface."""
|
2 |
|
3 |
import copy
|
4 |
+
import difflib
|
5 |
+
import inspect
|
6 |
import os
|
7 |
import pickle as pkl
|
8 |
import re
|
|
|
909 |
FutureWarning,
|
910 |
)
|
911 |
else:
|
912 |
+
suggested_keywords = self._suggest_keywords(k)
|
913 |
+
err_msg = f"{k} is not a valid keyword argument for PySRRegressor."
|
914 |
+
if len(suggested_keywords) > 0:
|
915 |
+
err_msg += f" Did you mean {' or '.join(suggested_keywords)}?"
|
916 |
+
raise TypeError(err_msg)
|
917 |
|
918 |
@classmethod
|
919 |
def from_file(
|
|
|
1995 |
|
1996 |
return self
|
1997 |
|
1998 |
+
def _suggest_keywords(self, k: str) -> List[str]:
|
1999 |
+
valid_keywords = [
|
2000 |
+
param
|
2001 |
+
for param in inspect.signature(self.__init__).parameters
|
2002 |
+
if param not in ["self", "kwargs"]
|
2003 |
+
]
|
2004 |
+
suggestions = difflib.get_close_matches(k, valid_keywords, n=3)
|
2005 |
+
return suggestions
|
2006 |
+
|
2007 |
def refresh(self, checkpoint_file=None) -> None:
|
2008 |
"""
|
2009 |
Update self.equations_ with any new options passed.
|
pysr/test/test.py
CHANGED
@@ -804,6 +804,26 @@ class TestHelpMessages(unittest.TestCase):
|
|
804 |
model.get_best()
|
805 |
print("Failed", opt["kwargs"])
|
806 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
807 |
|
808 |
TRUE_PREAMBLE = "\n".join(
|
809 |
[
|
|
|
804 |
model.get_best()
|
805 |
print("Failed", opt["kwargs"])
|
806 |
|
807 |
+
def test_suggest_keywords(self):
|
808 |
+
model = PySRRegressor()
|
809 |
+
# Easy
|
810 |
+
self.assertEqual(model._suggest_keywords("loss_function"), ["loss_function"])
|
811 |
+
|
812 |
+
# More complex, and with error
|
813 |
+
with self.assertRaises(TypeError) as cm:
|
814 |
+
model = PySRRegressor(ncyclesperiterationn=5)
|
815 |
+
|
816 |
+
self.assertIn("ncyclesperiterationn is not a valid keyword", str(cm.exception))
|
817 |
+
self.assertIn("Did you mean", str(cm.exception))
|
818 |
+
self.assertIn("ncycles_per_iteration or", str(cm.exception))
|
819 |
+
self.assertIn("niteration", str(cm.exception))
|
820 |
+
|
821 |
+
# Farther matches (this might need to be changed)
|
822 |
+
with self.assertRaises(TypeError) as cm:
|
823 |
+
model = PySRRegressor(operators=["+", "-"])
|
824 |
+
|
825 |
+
self.assertIn("unary_operators or binary_operators", str(cm.exception))
|
826 |
+
|
827 |
|
828 |
TRUE_PREAMBLE = "\n".join(
|
829 |
[
|