# File: strexp/lime/tests/test_lime_text.py
# Commit d61b9c7 by markytools: "added strexp"
import re
import unittest
import sklearn # noqa
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import f1_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline
import numpy as np
from lime.lime_text import LimeTextExplainer
from lime.lime_text import IndexedCharacters, IndexedString
class TestLimeText(unittest.TestCase):
    """Tests for ``LimeTextExplainer`` and the ``IndexedCharacters`` /
    ``IndexedString`` tokenisation helpers from ``lime.lime_text``."""

    # Display names for the two 20newsgroups categories, in the same order as
    # ``_NEWSGROUPS_CATEGORIES``.
    CLASS_NAMES = ['atheism', 'christian']
    # Index (into the 20newsgroups test split) of the document that every
    # explainer test explains.
    NEWSGROUPS_IDX = 83
    _NEWSGROUPS_CATEGORIES = ['alt.atheism', 'soc.religion.christian']
    # Lazily built ``(test_documents, fitted_pipeline, weighted_f1)`` tuple
    # shared by the explainer tests; populated by ``_newsgroups_model``.
    _newsgroups_cache = None

    @classmethod
    def _newsgroups_model(cls):
        """Return ``(test_documents, pipeline, f1)`` for the atheism/christian task.

        Fetches the 20newsgroups train/test splits for the two categories, fits
        a ``TfidfVectorizer(lowercase=False)`` + ``MultinomialNB(alpha=.01)``
        pipeline on the training split, and scores it (weighted F1) on the test
        split. The result is cached on the class, so the data is loaded and
        the model fitted once per run instead of once per test. It is built
        lazily rather than in ``setUpClass`` so that the string-indexing tests
        never depend on the dataset being available.
        """
        if cls._newsgroups_cache is None:
            newsgroups_train = fetch_20newsgroups(
                subset='train', categories=cls._NEWSGROUPS_CATEGORIES)
            newsgroups_test = fetch_20newsgroups(
                subset='test', categories=cls._NEWSGROUPS_CATEGORIES)
            vectorizer = TfidfVectorizer(lowercase=False)
            train_vectors = vectorizer.fit_transform(newsgroups_train.data)
            test_vectors = vectorizer.transform(newsgroups_test.data)
            nb = MultinomialNB(alpha=.01)
            nb.fit(train_vectors, newsgroups_train.target)
            f1 = f1_score(newsgroups_test.target, nb.predict(test_vectors),
                          average='weighted')
            cls._newsgroups_cache = (newsgroups_test.data,
                                     make_pipeline(vectorizer, nb), f1)
        return cls._newsgroups_cache

    def _explain(self, **explainer_kwargs):
        """Explain the shared test document with a fresh ``LimeTextExplainer``.

        ``explainer_kwargs`` (e.g. ``random_state``) are forwarded to the
        explainer constructor. The explanation is limited to 6 features.
        """
        documents, pipeline, _ = self._newsgroups_model()
        explainer = LimeTextExplainer(class_names=self.CLASS_NAMES,
                                      **explainer_kwargs)
        return explainer.explain_instance(documents[self.NEWSGROUPS_IDX],
                                          pipeline.predict_proba,
                                          num_features=6)

    def test_lime_text_explainer_good_regressor(self):
        """An explanation is produced and has exactly ``num_features`` entries."""
        _, _, f1 = self._newsgroups_model()
        # Sanity-check the classifier being explained: if it were no better
        # than chance, the explanation assertions below would be meaningless.
        self.assertGreater(f1, 0.5)
        exp = self._explain()
        self.assertIsNotNone(exp)
        self.assertEqual(6, len(exp.as_list()))

    def test_lime_text_tabular_equal_random_state(self):
        """Two explainers with the same ``random_state`` give identical maps."""
        exp_1 = self._explain(random_state=10)
        exp_2 = self._explain(random_state=10)
        self.assertEqual(exp_1.as_map(), exp_2.as_map())

    def test_lime_text_tabular_not_equal_random_state(self):
        """Explainers with different ``random_state`` values give different maps."""
        exp_1 = self._explain(random_state=10)
        exp_2 = self._explain(random_state=20)
        self.assertNotEqual(exp_1.as_map(), exp_2.as_map())

    def test_indexed_characters_bow(self):
        """Bag-of-characters mode: one vocab entry per distinct char, listing
        every position where that char occurs."""
        s = 'Please, take your time'
        inverse_vocab = ['P', 'l', 'e', 'a', 's', ',', ' ', 't', 'k', 'y',
                         'o', 'u', 'r', 'i', 'm']
        positions = [[0], [1], [2, 5, 11, 21], [3, 9],
                     [4], [6], [7, 12, 17], [8, 18], [10],
                     [13], [14], [15], [16], [19], [20]]
        ic = IndexedCharacters(s)
        np.testing.assert_array_equal(ic.as_np, np.array(list(s)))
        np.testing.assert_array_equal(ic.string_start, np.arange(len(s)))
        self.assertEqual(ic.inverse_vocab, inverse_vocab)
        self.assertEqual(ic.positions, positions)

    def test_indexed_characters_not_bow(self):
        """Non-bag mode: every character is its own feature at its own index."""
        s = 'Please, take your time'
        ic = IndexedCharacters(s, bow=False)
        np.testing.assert_array_equal(ic.as_np, np.array(list(s)))
        np.testing.assert_array_equal(ic.string_start, np.arange(len(s)))
        self.assertEqual(ic.inverse_vocab, list(s))
        np.testing.assert_array_equal(ic.positions, np.arange(len(s)))

    def test_indexed_string_regex(self):
        """Default regex splitting keeps separators as tokens but leaves them
        out of the vocabulary; repeated words share one vocab entry."""
        s = 'Please, take your time. Please'
        tokenized_string = np.array(
            ['Please', ', ', 'take', ' ', 'your', ' ', 'time', '. ', 'Please'])
        inverse_vocab = ['Please', 'take', 'your', 'time']
        start_positions = [0, 6, 8, 12, 13, 17, 18, 22, 24]
        positions = [[0, 8], [2], [4], [6]]
        indexed_string = IndexedString(s)
        np.testing.assert_array_equal(indexed_string.as_np, tokenized_string)
        np.testing.assert_array_equal(indexed_string.string_start,
                                      start_positions)
        self.assertEqual(indexed_string.inverse_vocab, inverse_vocab)
        # ``positions`` is a ragged list of lists, so compare it as a list:
        # np.array_equal cannot convert ragged input (NumPy >= 1.24 raises on
        # it) and would silently return False.
        self.assertEqual(indexed_string.positions, positions)

    def test_indexed_string_callable(self):
        """A callable tokenizer is used instead of the regex splitter."""
        s = 'aabbccddaa'

        def tokenizer(string):
            # Split the string into consecutive two-character tokens.
            return [string[i] + string[i + 1]
                    for i in range(0, len(string) - 1, 2)]

        tokenized_string = np.array(['aa', 'bb', 'cc', 'dd', 'aa'])
        inverse_vocab = ['aa', 'bb', 'cc', 'dd']
        start_positions = [0, 2, 4, 6, 8]
        positions = [[0, 4], [1], [2], [3]]
        indexed_string = IndexedString(s, tokenizer)
        np.testing.assert_array_equal(indexed_string.as_np, tokenized_string)
        np.testing.assert_array_equal(indexed_string.string_start,
                                      start_positions)
        self.assertEqual(indexed_string.inverse_vocab, inverse_vocab)
        # Ragged list of lists: compare as lists, not with np.array_equal.
        self.assertEqual(indexed_string.positions, positions)

    def test_indexed_string_inverse_removing_tokenizer(self):
        """Removing no words with a callable tokenizer reproduces the input."""
        s = 'This is a good movie. This, it is a great movie.'

        def tokenizer(string):
            return re.split(r'(?:\W+)|$', string)

        indexed_string = IndexedString(s, tokenizer)
        self.assertEqual(s, indexed_string.inverse_removing([]))

    def test_indexed_string_inverse_removing_regex(self):
        """Removing no words with the default regex splitter reproduces the input."""
        s = 'This is a good movie. This is a great movie'
        indexed_string = IndexedString(s)
        self.assertEqual(s, indexed_string.inverse_removing([]))
if __name__ == "__main__":
    # Allow running this module directly (python test_lime_text.py) as well
    # as through a test runner; unittest.main() discovers TestLimeText here.
    unittest.main()