tttc3 commited on
Commit
3821242
1 Parent(s): fbb7cf7

Cleaned test and docstring

Browse files
Files changed (2) hide show
  1. pysr/sr.py +0 -6
  2. test/test_torch.py +10 -9
pysr/sr.py CHANGED
@@ -1427,12 +1427,6 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
1427
  If :param`X` is a pandas dataframe, the column names will be used.
1428
  If variable_names are specified
1429
 
1430
- from_equation_file : bool, default=False
1431
- Allows model to be initialized/fit from a previous run that has
1432
- been saved to a file. If true, a value of y still needs to be
1433
- passed such that `nout_` can be determined, however, the values of
1434
- y are irrelevant and can be all zeros.
1435
-
1436
  Returns
1437
  -------
1438
  self : object
 
1427
  If :param`X` is a pandas dataframe, the column names will be used.
1428
  If variable_names are specified
1429
 
 
 
 
 
 
 
1430
  Returns
1431
  -------
1432
  self : object
test/test_torch.py CHANGED
@@ -23,6 +23,16 @@ class TestTorch(unittest.TestCase):
23
  )
24
 
25
  def test_pipeline_pandas(self):
 
 
 
 
 
 
 
 
 
 
26
  equations = pd.DataFrame(
27
  {
28
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
@@ -35,15 +45,6 @@ class TestTorch(unittest.TestCase):
35
  "equation_file.csv.bkup", sep="|"
36
  )
37
 
38
- X = pd.DataFrame(np.random.randn(100, 10))
39
- y = np.ones(X.shape[0])
40
- model = PySRRegressor(
41
- max_evals=10000,
42
- model_selection="accuracy",
43
- extra_sympy_mappings={},
44
- output_torch_format=True,
45
- )
46
- model.fit(X, y)
47
  model.refresh(checkpoint_file="equation_file.csv")
48
  tformat = model.pytorch()
49
  self.assertEqual(str(tformat), "_SingleSymPyModule(expression=cos(x1)**2)")
 
23
  )
24
 
25
  def test_pipeline_pandas(self):
26
+ X = pd.DataFrame(np.random.randn(100, 10))
27
+ y = np.ones(X.shape[0])
28
+ model = PySRRegressor(
29
+ max_evals=10000,
30
+ model_selection="accuracy",
31
+ extra_sympy_mappings={},
32
+ output_torch_format=True,
33
+ )
34
+ model.fit(X, y)
35
+
36
  equations = pd.DataFrame(
37
  {
38
  "Equation": ["1.0", "cos(x1)", "square(cos(x1))"],
 
45
  "equation_file.csv.bkup", sep="|"
46
  )
47
 
 
 
 
 
 
 
 
 
 
48
  model.refresh(checkpoint_file="equation_file.csv")
49
  tformat = model.pytorch()
50
  self.assertEqual(str(tformat), "_SingleSymPyModule(expression=cos(x1)**2)")