MilesCranmer committed
Commit 0bf77e2
1 Parent(s): 4582e28

Autogenerate docstring for PySR Regressor

Files changed (1):
  1. pysr/sklearn.py (+36 -2)
pysr/sklearn.py CHANGED

@@ -1,5 +1,6 @@
 from pysr import pysr, best_row
 from sklearn.base import BaseEstimator
+import inspect


 class PySRRegressor(BaseEstimator):
@@ -42,12 +43,28 @@ class PySRRegressor(BaseEstimator):
         else:
             raise NotImplementedError

-    def fit(self, X, y, weights=None):
+    def fit(self, X, y, weights=None, variable_names=None):
+        """Search for equations to fit the dataset.
+
+        :param X: 2D array. Rows are examples, columns are features. If pandas DataFrame, the columns are used for variable names (so make sure they don't contain spaces).
+        :type X: np.ndarray/pandas.DataFrame
+        :param y: 1D array (rows are examples) or 2D array (rows are examples, columns are outputs). Putting in a 2D array will trigger a search for equations for each feature of y.
+        :type y: np.ndarray
+        :param weights: Optional. Same shape as y. Each element is how to weight the mean-square-error loss for that particular element of y.
+        :type weights: np.ndarray
+        :param variable_names: a list of names for the variables, other than "x0", "x1", etc.
+        :type variable_names: list
+        """
+        if variable_names is None:
+            if "variable_names" in self.params:
+                variable_names = self.params["variable_names"]
+
         self.equations = pysr(
             X=X,
             y=y,
             weights=weights,
-            **self.params,
+            variable_names=variable_names,
+            **{k: v for k, v in self.params.items() if k != "variable_names"},
         )
         return self

@@ -56,3 +73,20 @@ class PySRRegressor(BaseEstimator):
         np_format = equation_row["lambda_format"]

         return np_format(X)
+
+
+# Add the docs from pysr() to PySRRegressor():
+
+_pysr_docstring_split = []
+_start_recording = False
+for line in inspect.getdoc(pysr).split("\n"):
+    # Skip docs on "X" and "y"
+    if ":param binary_operators:" in line:
+        _start_recording = True
+    if ":returns:" in line:
+        _start_recording = False
+    if _start_recording:
+        _pysr_docstring_split.append(line)
+_pysr_docstring = "\n\t".join(_pysr_docstring_split)
+
+PySRRegressor.__init__.__doc__ += _pysr_docstring
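Note: for context, here is a minimal usage sketch of the revised fit() signature. The toy data, the operator choices, the constructor keyword arguments, and the import path are illustrative assumptions (constructor kwargs are assumed to end up in self.params, which the diff shows being forwarded to pysr()); none of this is fixed by the commit itself.

# Hypothetical usage sketch: data, operators, and constructor kwargs below
# are assumptions for illustration, not defined by this commit.
import numpy as np
from pysr.sklearn import PySRRegressor

X = np.random.randn(100, 2)           # 100 examples, 2 features
y = 2.5 * np.cos(X[:, 0]) + X[:, 1]   # target expression to rediscover

# Assumed: constructor kwargs are collected into self.params, which fit()
# forwards to pysr() as shown in the diff above.
model = PySRRegressor(binary_operators=["+", "*"], unary_operators=["cos"])

# New in this commit: variable_names can be passed directly to fit();
# if omitted, a "variable_names" entry in self.params is used instead.
model.fit(X, y, variable_names=["temperature", "pressure"])

print(model.equations)  # the table of equations returned by pysr()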
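Note: the module-level loop at the end of the file is what the commit title describes: it copies the parameter documentation of pysr() (everything from ":param binary_operators:" up to, but not including, ":returns:") onto PySRRegressor.__init__.__doc__. A self-contained sketch of the same pattern, using hypothetical stand-in functions rather than PySR itself:

import inspect

def search(X, y, binary_operators=None, unary_operators=None):
    """Run a hypothetical equation search.

    :param binary_operators: Operators such as "+" or "*".
    :param unary_operators: Operators such as "cos".
    :returns: A table of results.
    """

class Wrapper:
    def __init__(self, **params):
        """Scikit-learn-style wrapper around search().

        """
        self.params = params

# Record only the parameter block of search()'s docstring: start recording
# at ":param binary_operators:" and stop at ":returns:".
copied = []
recording = False
for line in inspect.getdoc(search).split("\n"):
    if ":param binary_operators:" in line:
        recording = True
    if ":returns:" in line:
        recording = False
    if recording:
        copied.append(line)

# Splice the copied block into the wrapper's constructor docstring,
# mirroring what the commit does for PySRRegressor.__init__.
Wrapper.__init__.__doc__ += "\n\t".join(copied)
print(Wrapper.__init__.__doc__)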