Spaces:
Build error
Build error
import unittest | |
from unittest import TestCase | |
import numpy as np | |
from sklearn.datasets import load_iris | |
from lime.discretize import QuartileDiscretizer, DecileDiscretizer, EntropyDiscretizer | |
class TestDiscretize(TestCase):
    """Tests for the quartile, decile and entropy discretizers on iris data."""

    def setUp(self):
        """Load the iris features, labels and feature names for each test."""
        iris = load_iris()
        self.feature_names = iris.feature_names
        self.x = iris.data
        self.y = iris.target

    def _round_trip(self, DiscretizerClass, random_state):
        """Discretize and then undiscretize ``self.x`` with a new discretizer.

        A fresh ``DiscretizerClass`` instance is built on every call so that
        callers always compare the output of independent instances.

        Args:
            DiscretizerClass: discretizer class to instantiate.
            random_state: value forwarded as the ``random_state`` keyword.

        Returns:
            The result of ``undiscretize(discretize(self.x))``.
        """
        discretizer = DiscretizerClass(self.x, [], self.feature_names, self.y,
                                       random_state=random_state)
        return discretizer.undiscretize(discretizer.discretize(self.x))

    def check_random_state_for_discretizer_class(self, DiscretizerClass):
        """Check that ``random_state`` controls the discretizer's randomness.

        Two things are verified, first with integer seeds and then with
        ``np.random.RandomState`` objects:

        * the same seed gives identical round-trip results for two different
          discretizer instances;
        * two different seeds give round-trip results that differ.

        Args:
            DiscretizerClass: discretizer class under test.
        """
        # Only used to make assertion failures identify the class under test.
        class_name = getattr(DiscretizerClass, '__name__',
                             repr(DiscretizerClass))

        # Each entry turns a seed into the value passed as ``random_state``:
        # the plain integer first, then a RandomState built from it.
        state_factories = (
            ('int seed', lambda seed: seed),
            ('RandomState', np.random.RandomState),
        )

        # ------------------------------------------------------------------
        # The same random_state must produce the same results for different
        # discretizer instances.
        # ------------------------------------------------------------------
        for label, make_state in state_factories:
            x_1 = self._round_trip(DiscretizerClass, make_state(10))
            x_2 = self._round_trip(DiscretizerClass, make_state(10))
            self.assertTrue(
                np.array_equal(x_1, x_2),
                '{} ({}): equal seeds gave different results'.format(
                    class_name, label))

        # ------------------------------------------------------------------
        # Two different random_state values must produce different results
        # for different discretizer instances.
        # ------------------------------------------------------------------
        for label, make_state in state_factories:
            x_1 = self._round_trip(DiscretizerClass, make_state(10))
            x_2 = self._round_trip(DiscretizerClass, make_state(20))
            self.assertFalse(
                np.array_equal(x_1, x_2),
                '{} ({}): different seeds gave identical results'.format(
                    class_name, label))

    def test_random_state(self):
        """Run the random_state checks against every discretizer class."""
        for discretizer_class in (QuartileDiscretizer,
                                  DecileDiscretizer,
                                  EntropyDiscretizer):
            self.check_random_state_for_discretizer_class(discretizer_class)

    def test_feature_names_1(self):
        """QuartileDiscretizer produces the expected bin names per feature."""
        self.maxDiff = None  # show the full dict diff on failure
        discretizer = QuartileDiscretizer(self.x, [], self.feature_names,
                                          self.y, random_state=10)
        expected = {
            0: ['sepal length (cm) <= 5.10',
                '5.10 < sepal length (cm) <= 5.80',
                '5.80 < sepal length (cm) <= 6.40',
                'sepal length (cm) > 6.40'],
            1: ['sepal width (cm) <= 2.80',
                '2.80 < sepal width (cm) <= 3.00',
                '3.00 < sepal width (cm) <= 3.30',
                'sepal width (cm) > 3.30'],
            2: ['petal length (cm) <= 1.60',
                '1.60 < petal length (cm) <= 4.35',
                '4.35 < petal length (cm) <= 5.10',
                'petal length (cm) > 5.10'],
            3: ['petal width (cm) <= 0.30',
                '0.30 < petal width (cm) <= 1.30',
                '1.30 < petal width (cm) <= 1.80',
                'petal width (cm) > 1.80'],
        }
        self.assertDictEqual(expected, discretizer.names)

    def test_feature_names_2(self):
        """DecileDiscretizer produces the expected bin names per feature."""
        self.maxDiff = None  # show the full dict diff on failure
        discretizer = DecileDiscretizer(self.x, [], self.feature_names,
                                        self.y, random_state=10)
        expected = {
            0: ['sepal length (cm) <= 4.80',
                '4.80 < sepal length (cm) <= 5.00',
                '5.00 < sepal length (cm) <= 5.27',
                '5.27 < sepal length (cm) <= 5.60',
                '5.60 < sepal length (cm) <= 5.80',
                '5.80 < sepal length (cm) <= 6.10',
                '6.10 < sepal length (cm) <= 6.30',
                '6.30 < sepal length (cm) <= 6.52',
                '6.52 < sepal length (cm) <= 6.90',
                'sepal length (cm) > 6.90'],
            1: ['sepal width (cm) <= 2.50',
                '2.50 < sepal width (cm) <= 2.70',
                '2.70 < sepal width (cm) <= 2.80',
                '2.80 < sepal width (cm) <= 3.00',
                '3.00 < sepal width (cm) <= 3.10',
                '3.10 < sepal width (cm) <= 3.20',
                '3.20 < sepal width (cm) <= 3.40',
                '3.40 < sepal width (cm) <= 3.61',
                'sepal width (cm) > 3.61'],
            2: ['petal length (cm) <= 1.40',
                '1.40 < petal length (cm) <= 1.50',
                '1.50 < petal length (cm) <= 1.70',
                '1.70 < petal length (cm) <= 3.90',
                '3.90 < petal length (cm) <= 4.35',
                '4.35 < petal length (cm) <= 4.64',
                '4.64 < petal length (cm) <= 5.00',
                '5.00 < petal length (cm) <= 5.32',
                '5.32 < petal length (cm) <= 5.80',
                'petal length (cm) > 5.80'],
            3: ['petal width (cm) <= 0.20',
                '0.20 < petal width (cm) <= 0.40',
                '0.40 < petal width (cm) <= 1.16',
                '1.16 < petal width (cm) <= 1.30',
                '1.30 < petal width (cm) <= 1.50',
                '1.50 < petal width (cm) <= 1.80',
                '1.80 < petal width (cm) <= 1.90',
                '1.90 < petal width (cm) <= 2.20',
                'petal width (cm) > 2.20'],
        }
        self.assertDictEqual(expected, discretizer.names)

    def test_feature_names_3(self):
        """EntropyDiscretizer produces the expected bin names per feature."""
        self.maxDiff = None  # show the full dict diff on failure
        discretizer = EntropyDiscretizer(self.x, [], self.feature_names,
                                         self.y, random_state=10)
        expected = {
            0: ['sepal length (cm) <= 4.85',
                '4.85 < sepal length (cm) <= 5.45',
                '5.45 < sepal length (cm) <= 5.55',
                '5.55 < sepal length (cm) <= 5.85',
                '5.85 < sepal length (cm) <= 6.15',
                '6.15 < sepal length (cm) <= 7.05',
                'sepal length (cm) > 7.05'],
            1: ['sepal width (cm) <= 2.45',
                '2.45 < sepal width (cm) <= 2.95',
                '2.95 < sepal width (cm) <= 3.05',
                '3.05 < sepal width (cm) <= 3.35',
                '3.35 < sepal width (cm) <= 3.45',
                '3.45 < sepal width (cm) <= 3.55',
                'sepal width (cm) > 3.55'],
            2: ['petal length (cm) <= 2.45',
                '2.45 < petal length (cm) <= 4.45',
                '4.45 < petal length (cm) <= 4.75',
                '4.75 < petal length (cm) <= 5.15',
                'petal length (cm) > 5.15'],
            3: ['petal width (cm) <= 0.80',
                '0.80 < petal width (cm) <= 1.35',
                '1.35 < petal width (cm) <= 1.75',
                '1.75 < petal width (cm) <= 1.85',
                'petal width (cm) > 1.85'],
        }
        self.assertDictEqual(expected, discretizer.names)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()