"""Tests for lime.lime_text explainers and string indexing helpers."""

import re
import unittest

import sklearn  # noqa
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import f1_score
from sklearn.naive_bayes import MultinomialNB
from sklearn.pipeline import make_pipeline

import numpy as np

from lime.lime_text import LimeTextExplainer
from lime.lime_text import IndexedCharacters, IndexedString


class TestLimeText(unittest.TestCase):

    def test_lime_text_explainer_good_regressor(self):
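        """LimeTextExplainer returns an explanation of the requested size."""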
        categories = ['alt.atheism', 'soc.religion.christian']
        newsgroups_train = fetch_20newsgroups(subset='train',
                                              categories=categories)
        newsgroups_test = fetch_20newsgroups(subset='test',
                                             categories=categories)
        class_names = ['atheism', 'christian']
        vectorizer = TfidfVectorizer(lowercase=False)
        train_vectors = vectorizer.fit_transform(newsgroups_train.data)
        test_vectors = vectorizer.transform(newsgroups_test.data)
        nb = MultinomialNB(alpha=.01)
        nb.fit(train_vectors, newsgroups_train.target)
        pred = nb.predict(test_vectors)
        # Sanity check only; the f1 score itself is not asserted on.
        f1_score(newsgroups_test.target, pred, average='weighted')
        c = make_pipeline(vectorizer, nb)
        explainer = LimeTextExplainer(class_names=class_names)
        idx = 83
        exp = explainer.explain_instance(newsgroups_test.data[idx],
                                         c.predict_proba, num_features=6)
        self.assertIsNotNone(exp)
        self.assertEqual(6, len(exp.as_list()))

    def test_lime_text_explainer_equal_random_state(self):
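        """Equal random_state values yield identical explanations."""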
        categories = ['alt.atheism', 'soc.religion.christian']
        newsgroups_train = fetch_20newsgroups(subset='train',
                                              categories=categories)
        newsgroups_test = fetch_20newsgroups(subset='test',
                                             categories=categories)
        class_names = ['atheism', 'christian']
        vectorizer = TfidfVectorizer(lowercase=False)
        train_vectors = vectorizer.fit_transform(newsgroups_train.data)
        test_vectors = vectorizer.transform(newsgroups_test.data)
        nb = MultinomialNB(alpha=.01)
        nb.fit(train_vectors, newsgroups_train.target)
        pred = nb.predict(test_vectors)
        f1_score(newsgroups_test.target, pred, average='weighted')
        c = make_pipeline(vectorizer, nb)

        explainer = LimeTextExplainer(
            class_names=class_names, random_state=10)
        exp_1 = explainer.explain_instance(newsgroups_test.data[83],
                                           c.predict_proba, num_features=6)

        explainer = LimeTextExplainer(
            class_names=class_names, random_state=10)
        exp_2 = explainer.explain_instance(newsgroups_test.data[83],
                                           c.predict_proba, num_features=6)

        self.assertEqual(exp_1.as_map(), exp_2.as_map())

    def test_lime_text_explainer_not_equal_random_state(self):
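        """Different random_state values yield different explanations."""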
        categories = ['alt.atheism', 'soc.religion.christian']
        newsgroups_train = fetch_20newsgroups(subset='train',
                                              categories=categories)
        newsgroups_test = fetch_20newsgroups(subset='test',
                                             categories=categories)
        class_names = ['atheism', 'christian']
        vectorizer = TfidfVectorizer(lowercase=False)
        train_vectors = vectorizer.fit_transform(newsgroups_train.data)
        test_vectors = vectorizer.transform(newsgroups_test.data)
        nb = MultinomialNB(alpha=.01)
        nb.fit(train_vectors, newsgroups_train.target)
        pred = nb.predict(test_vectors)
        f1_score(newsgroups_test.target, pred, average='weighted')
        c = make_pipeline(vectorizer, nb)

        explainer = LimeTextExplainer(
            class_names=class_names, random_state=10)
        exp_1 = explainer.explain_instance(newsgroups_test.data[83],
                                           c.predict_proba, num_features=6)

        explainer = LimeTextExplainer(
            class_names=class_names, random_state=20)
        exp_2 = explainer.explain_instance(newsgroups_test.data[83],
                                           c.predict_proba, num_features=6)

        self.assertNotEqual(exp_1.as_map(), exp_2.as_map())

    def test_indexed_characters_bow(self):
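        """IndexedCharacters with bow=True groups repeated characters."""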
        s = 'Please, take your time'
        inverse_vocab = ['P', 'l', 'e', 'a', 's', ',', ' ', 't', 'k', 'y', 'o', 'u', 'r', 'i', 'm']
        positions = [[0], [1], [2, 5, 11, 21], [3, 9],
                     [4], [6], [7, 12, 17], [8, 18], [10],
                     [13], [14], [15], [16], [19], [20]]
        ic = IndexedCharacters(s)

        self.assertTrue(np.array_equal(ic.as_np, np.array(list(s))))
        self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s))))
        self.assertEqual(ic.inverse_vocab, inverse_vocab)
        self.assertEqual(ic.positions, positions)

    def test_indexed_characters_not_bow(self):
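        """IndexedCharacters with bow=False keeps one entry per position."""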
        s = 'Please, take your time'

        ic = IndexedCharacters(s, bow=False)

        self.assertTrue(np.array_equal(ic.as_np, np.array(list(s))))
        self.assertTrue(np.array_equal(ic.string_start, np.arange(len(s))))
        self.assertEqual(ic.inverse_vocab, list(s))
        self.assertTrue(np.array_equal(ic.positions, np.arange(len(s))))

    def test_indexed_string_regex(self):
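        """Default regex splitter keeps word tokens and their separators."""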
        s = 'Please, take your time. Please'
        tokenized_string = np.array(
            ['Please', ', ', 'take', ' ', 'your', ' ', 'time', '. ', 'Please'])
        inverse_vocab = ['Please', 'take', 'your', 'time']
        start_positions = [0, 6, 8, 12, 13, 17, 18, 22, 24]
        positions = [[0, 8], [2], [4], [6]]
        indexed_string = IndexedString(s)

        self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string))
        self.assertTrue(np.array_equal(indexed_string.string_start, start_positions))
        self.assertEqual(indexed_string.inverse_vocab, inverse_vocab)
        self.assertTrue(np.array_equal(indexed_string.positions, positions))

    def test_indexed_string_callable(self):
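        """IndexedString accepts a callable tokenizer."""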
        s = 'aabbccddaa'

        def tokenizer(string):
            return [string[i] + string[i + 1] for i in range(0, len(string) - 1, 2)]

        tokenized_string = np.array(['aa', 'bb', 'cc', 'dd', 'aa'])
        inverse_vocab = ['aa', 'bb', 'cc', 'dd']
        start_positions = [0, 2, 4, 6, 8]
        positions = [[0, 4], [1], [2], [3]]
        indexed_string = IndexedString(s, tokenizer)

        self.assertTrue(np.array_equal(indexed_string.as_np, tokenized_string))
        self.assertTrue(np.array_equal(indexed_string.string_start, start_positions))
        self.assertEqual(indexed_string.inverse_vocab, inverse_vocab)
        self.assertTrue(np.array_equal(indexed_string.positions, positions))

    def test_indexed_string_inverse_removing_tokenizer(self):
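        """Removing nothing restores the string with a custom tokenizer."""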
        s = 'This is a good movie. This, it is a great movie.'

        def tokenizer(string):
            return re.split(r'(?:\W+)|$', string)

        indexed_string = IndexedString(s, tokenizer)

        self.assertEqual(s, indexed_string.inverse_removing([]))

    def test_indexed_string_inverse_removing_regex(self):
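        """Removing nothing restores the string with the default splitter."""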
        s = 'This is a good movie. This is a great movie'
        indexed_string = IndexedString(s)

        self.assertEqual(s, indexed_string.inverse_removing([]))


if __name__ == '__main__':
    unittest.main()