MilesCranmer commited on
Commit
7de1010
1 Parent(s): 6487b49

Ensure that equations are up-to-date when getting model representations

Browse files
Files changed (1) hide show
  1. pysr/sklearn.py +12 -2
pysr/sklearn.py CHANGED
@@ -93,22 +93,32 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
93
  )
94
  return self
95
 
 
 
 
 
 
96
  def predict(self, X):
 
97
  np_format = self.get_best()["lambda_format"]
98
  return np_format(X)
99
 
100
  def sympy(self):
 
101
  return self.get_best()["sympy_format"]
102
 
103
  def latex(self):
 
104
  return self.sympy().simplify()
105
 
106
  def jax(self):
107
- self.equations = get_hof(output_jax_format=True)
 
108
  return self.get_best()["jax_format"]
109
 
110
  def pytorch(self):
111
- self.equations = get_hof(output_torch_format=True)
 
112
  return self.get_best()["torch_format"]
113
 
114
 
 
93
  )
94
  return self
95
 
96
+ def refresh(self):
97
+ # Updates self.equations with any new options passed,
98
+ # such as extra_sympy_mappings.
99
+ self.equations = get_hof(**self.params)
100
+
101
  def predict(self, X):
102
+ self.refresh()
103
  np_format = self.get_best()["lambda_format"]
104
  return np_format(X)
105
 
106
  def sympy(self):
107
+ self.refresh()
108
  return self.get_best()["sympy_format"]
109
 
110
  def latex(self):
111
+ self.refresh()
112
  return self.sympy().simplify()
113
 
114
  def jax(self):
115
+ self.set_params(output_jax_format=True)
116
+ self.refresh()
117
  return self.get_best()["jax_format"]
118
 
119
  def pytorch(self):
120
+ self.set_params(output_torch_format=True)
121
+ self.refresh()
122
  return self.get_best()["torch_format"]
123
 
124