MilesCranmer commited on
Commit
7a792a8
1 Parent(s): c4f15ef

Add test of other multioutput indexes

Browse files
Files changed (1) hide show
  1. test/test.py +4 -0
test/test.py CHANGED
@@ -61,6 +61,10 @@ class TestPipeline(unittest.TestCase):
61
  self.assertLessEqual(mse1, 1e-4)
62
  self.assertLessEqual(mse2, 1e-4)
63
 
 
 
 
 
64
  def test_multioutput_weighted_with_callable_temp_equation(self):
65
  y = self.X[:, [0, 1]] ** 2
66
  w = np.random.rand(*y.shape)
 
61
  self.assertLessEqual(mse1, 1e-4)
62
  self.assertLessEqual(mse2, 1e-4)
63
 
64
+ bad_y = model.predict(self.X, index=[0, 0])
65
+ bad_mse = np.average((bad_y - y) ** 2)
66
+ self.assertGreater(bad_mse, 1e-4)
67
+
68
  def test_multioutput_weighted_with_callable_temp_equation(self):
69
  y = self.X[:, [0, 1]] ** 2
70
  w = np.random.rand(*y.shape)