MilesCranmer commited on
Commit
22eb380
1 Parent(s): 2621e9c

Add test for pickling + units

Browse files
Files changed (1) hide show
  1. pysr/test/test.py +25 -1
pysr/test/test.py CHANGED
@@ -1010,10 +1010,15 @@ class TestDimensionalConstraints(unittest.TestCase):
1010
  """Check that units are propagated correctly."""
1011
  X = np.ones((100, 3))
1012
  y = np.ones((100, 1))
 
 
1013
  model = PySRRegressor(
1014
  binary_operators=["+", "*"],
1015
  early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
1016
- **self.default_test_kwargs,
 
 
 
1017
  complexity_of_constants=10,
1018
  weight_mutate_constant=0.0,
1019
  should_optimize_constants=False,
@@ -1021,6 +1026,7 @@ class TestDimensionalConstraints(unittest.TestCase):
1021
  deterministic=True,
1022
  procs=0,
1023
  random_state=0,
 
1024
  )
1025
  model.fit(
1026
  X,
@@ -1034,6 +1040,24 @@ class TestDimensionalConstraints(unittest.TestCase):
1034
  self.assertIn("x2", best["equation"])
1035
  self.assertEqual(best["complexity"], 3)
1036
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1037
 
1038
  def runtests():
1039
  """Run all tests in test.py."""
 
1010
  """Check that units are propagated correctly."""
1011
  X = np.ones((100, 3))
1012
  y = np.ones((100, 1))
1013
+ temp_dir = Path(tempfile.mkdtemp())
1014
+ equation_file = str(temp_dir / "equation_file.csv")
1015
  model = PySRRegressor(
1016
  binary_operators=["+", "*"],
1017
  early_stop_condition="(l, c) -> l < 1e-8 && c == 3",
1018
+ progress=False,
1019
+ model_selection="accuracy",
1020
+ niterations=DEFAULT_NITERATIONS * 2,
1021
+ populations=DEFAULT_POPULATIONS * 2,
1022
  complexity_of_constants=10,
1023
  weight_mutate_constant=0.0,
1024
  should_optimize_constants=False,
 
1026
  deterministic=True,
1027
  procs=0,
1028
  random_state=0,
1029
+ equation_file=equation_file,
1030
  )
1031
  model.fit(
1032
  X,
 
1040
  self.assertIn("x2", best["equation"])
1041
  self.assertEqual(best["complexity"], 3)
1042
 
1043
+ # With pkl file:
1044
+ pkl_file = str(temp_dir / "equation_file.pkl")
1045
+ model2 = PySRRegressor.from_file(pkl_file)
1046
+ best2 = model2.get_best()
1047
+ self.assertIn("x0", best2["equation"])
1048
+
1049
+ # From csv file alone (we need to delete pkl file:)
1050
+ # First, we delete the pkl file:
1051
+ os.remove(pkl_file)
1052
+ model3 = PySRRegressor.from_file(
1053
+ equation_file, binary_operators=["+", "*"], n_features_in=X.shape[1]
1054
+ )
1055
+ best3 = model3.get_best()
1056
+ self.assertIn("x0", best3["equation"])
1057
+
1058
+
1059
+ # TODO: Determine desired behavior if second .fit() call does not have units
1060
+
1061
 
1062
  def runtests():
1063
  """Run all tests in test.py."""