deepsource-autofix[bot] commited on
Commit
1662e82
·
1 Parent(s): b6f3a08

Format code with black

Browse files
Files changed (2) hide show
  1. pysr/sr.py +1 -1
  2. test/test.py +3 -3
pysr/sr.py CHANGED
@@ -798,7 +798,7 @@ class PySRRegressor(BaseEstimator, RegressorMixin):
798
  if self.multioutput:
799
  return [eq["jax_format"] for eq in best]
800
  return best["jax_format"]
801
-
802
  def pytorch(self):
803
  self.set_params(output_torch_format=True)
804
  self.refresh()
 
798
  if self.multioutput:
799
  return [eq["jax_format"] for eq in best]
800
  return best["jax_format"]
801
+
802
  def pytorch(self):
803
  self.set_params(output_torch_format=True)
804
  self.refresh()
test/test.py CHANGED
@@ -154,7 +154,7 @@ class TestPipeline(unittest.TestCase):
154
  self.assertIn("T", model.latex())
155
  self.assertIn("x", model.latex())
156
  self.assertLessEqual(model.get_best()["loss"], 1e-2)
157
- fn = model.get_best()['lambda_format']
158
  self.assertListEqual(list(sorted(fn._selection)), [0, 1])
159
  X2 = pd.DataFrame(
160
  {
@@ -202,7 +202,7 @@ class TestBest(unittest.TestCase):
202
  def test_best_lambda(self):
203
  X = np.random.randn(10, 2)
204
  y = np.cos(X[:, 0]) ** 2
205
- for f in [self.model.predict, self.equations.iloc[-1]['lambda_format']]:
206
  np.testing.assert_almost_equal(f(X), y, decimal=4)
207
 
208
 
@@ -231,4 +231,4 @@ class TestFeatureSelection(unittest.TestCase):
231
  self.assertEqual(set(selected_var_names), set("x2 x3".split(" ")))
232
  np.testing.assert_array_equal(
233
  np.sort(selected_X, axis=1), np.sort(X[:, [2, 3]], axis=1)
234
- )
 
154
  self.assertIn("T", model.latex())
155
  self.assertIn("x", model.latex())
156
  self.assertLessEqual(model.get_best()["loss"], 1e-2)
157
+ fn = model.get_best()["lambda_format"]
158
  self.assertListEqual(list(sorted(fn._selection)), [0, 1])
159
  X2 = pd.DataFrame(
160
  {
 
202
  def test_best_lambda(self):
203
  X = np.random.randn(10, 2)
204
  y = np.cos(X[:, 0]) ** 2
205
+ for f in [self.model.predict, self.equations.iloc[-1]["lambda_format"]]:
206
  np.testing.assert_almost_equal(f(X), y, decimal=4)
207
 
208
 
 
231
  self.assertEqual(set(selected_var_names), set("x2 x3".split(" ")))
232
  np.testing.assert_array_equal(
233
  np.sort(selected_X, axis=1), np.sort(X[:, [2, 3]], axis=1)
234
+ )