import re
from contextlib import contextmanager

import numpy as np
import torch
import torch.nn.functional as F
from fuzzysearch import find_near_matches
from pyarabic import araby
from torch import nn
from transformers import AutoTokenizer, BertModel, BertPreTrainedModel, pipeline
from transformers.modeling_outputs import SequenceClassifierOutput

from .preprocess import ArabertPreprocessor, url_regexes, user_mention_regex

# Matches any single character (newlines included, because of DOTALL) that is
# repeated three or more times in a row; group 1 captures that character so a
# substitution can collapse the run.
multiple_char_pattern = re.compile(r"(.)\1{2,}", re.DOTALL)

# ASAD-NEW_AraBERT_PREP-Balanced
class NewArabicPreprocessorBalanced(ArabertPreprocessor):
    """ArabertPreprocessor with an extra clean-up pass for "UBC-NLP" models.

    When the model name contains "UBC-NLP", :meth:`preprocess` runs
    :meth:`ubc_prep`; for every other model it defers to the parent
    preprocessor unchanged.
    """

    def __init__(
        self,
        model_name: str,
        keep_emojis: bool = False,
        remove_html_markup: bool = True,
        replace_urls_emails_mentions: bool = True,
        strip_tashkeel: bool = True,
        strip_tatweel: bool = True,
        insert_white_spaces: bool = True,
        remove_non_digit_repetition: bool = True,
        replace_slash_with_dash: bool = None,
        map_hindi_numbers_to_arabic: bool = None,
        apply_farasa_segmentation: bool = None,
    ):
        """Configure the parent preprocessor and remember the model name.

        Every option is forwarded to ``ArabertPreprocessor`` by keyword. For
        model names containing "UBC-NLP" or "CAMeL-Lab", ``keep_emojis`` and
        ``remove_non_digit_repetition`` are forced to ``True`` regardless of
        the values passed in.
        """
        if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
            keep_emojis = True
            remove_non_digit_repetition = True
        super().__init__(
            model_name=model_name,
            keep_emojis=keep_emojis,
            remove_html_markup=remove_html_markup,
            replace_urls_emails_mentions=replace_urls_emails_mentions,
            strip_tashkeel=strip_tashkeel,
            strip_tatweel=strip_tatweel,
            insert_white_spaces=insert_white_spaces,
            remove_non_digit_repetition=remove_non_digit_repetition,
            replace_slash_with_dash=replace_slash_with_dash,
            map_hindi_numbers_to_arabic=map_hindi_numbers_to_arabic,
            apply_farasa_segmentation=apply_farasa_segmentation,
        )
        # Remember the name exactly as passed in; preprocess() branches on it.
        self.true_model_name = model_name

    def preprocess(self, text):
        """Return the preprocessed form of ``text``.

        "UBC-NLP" models get the extended :meth:`ubc_prep` pipeline. Any other
        model falls back to the parent implementation instead of silently
        returning ``None``.
        """
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        return super().preprocess(text)

    def ubc_prep(self, text):
        """Normalise ``text`` for "UBC-NLP" models, then run the parent pass.

        Steps, in order: whitespace flattening, tashkeel/tatweel stripping,
        URL and mention placeholders, hashtag/underscore splitting, a fixed
        table of escaped-emoji and emoji substitutions, removal of leftover
        literal unicode escapes, the parent ``preprocess``, and a final
        whitespace squeeze.
        """
        text = re.sub(r"\s", " ", text)
        # These are the two-character sequences backslash-n / backslash-r that
        # occur in escaped text, not real newline characters.
        text = text.replace("\\n", " ")
        text = text.replace("\\r", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # Replace every URL variant with a placeholder, then collapse runs of
        # placeholders into one.
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        text = re.sub(r"(URL\s*)+", " URL ", text)
        # Same treatment for user mentions.
        text = re.sub(user_mention_regex, " USER ", text)
        text = re.sub(r"(USER\s*)+", " USER ", text)
        # Keep hashtag words but mark the '#' itself; split snake_case tags.
        text = text.replace("#", " HASH ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        # Literal "\UXXXXXXXX" escape sequences are mapped to a replacement
        # string (or dropped). Order matters: later rules see the output of
        # earlier ones.
        # NOTE(review): the emoji/Arabic literals below look mis-encoded
        # (mojibake) in this copy of the file; they are kept exactly as found.
        # Confirm them against the original file encoding.
        text = text.replace("\\U0001f97a", "๐Ÿฅบ")
        text = text.replace("\\U0001f928", "๐Ÿคจ")
        text = text.replace("\\U0001f9d8", "๐Ÿ˜€")
        text = text.replace("\\U0001f975", "๐Ÿ˜ฅ")
        text = text.replace("\\U0001f92f", "๐Ÿ˜ฒ")
        text = text.replace("\\U0001f92d", "๐Ÿคญ")
        text = text.replace("\\U0001f9d1", "๐Ÿ˜")
        text = text.replace("\\U000e0067", "")
        text = text.replace("\\U000e006e", "")
        text = text.replace("\\U0001f90d", "โ™ฅ")
        text = text.replace("\\U0001f973", "๐ŸŽ‰")
        text = text.replace("\\U0001fa79", "")
        text = text.replace("\\U0001f92b", "๐Ÿค")
        text = text.replace("\\U0001f9da", "๐Ÿฆ‹")
        text = text.replace("\\U0001f90e", "โ™ฅ")
        text = text.replace("\\U0001f9d0", "๐Ÿง")
        text = text.replace("\\U0001f9cf", "")
        text = text.replace("\\U0001f92c", "๐Ÿ˜ ")
        text = text.replace("\\U0001f9f8", "๐Ÿ˜ธ")
        text = text.replace("\\U0001f9b6", "๐Ÿ’ฉ")
        text = text.replace("\\U0001f932", "๐Ÿคฒ")
        text = text.replace("\\U0001f9e1", "๐Ÿงก")
        text = text.replace("\\U0001f974", "โ˜น")
        text = text.replace("\\U0001f91f", "")
        text = text.replace("\\U0001f9fb", "๐Ÿ’ฉ")
        text = text.replace("\\U0001f92a", "๐Ÿคช")
        text = text.replace("\\U0001f9fc", "")
        text = text.replace("\\U000e0065", "")
        text = text.replace("\\U0001f92e", "๐Ÿ’ฉ")
        text = text.replace("\\U000e007f", "")
        text = text.replace("\\U0001f970", "๐Ÿฅฐ")
        text = text.replace("\\U0001f929", "๐Ÿคฉ")
        text = text.replace("\\U0001f6f9", "")
        # Direct emoji-to-emoji / emoji-to-word substitutions.
        text = text.replace("๐Ÿค", "โ™ฅ")
        text = text.replace("๐Ÿฆ ", "๐Ÿ˜ท")
        text = text.replace("๐Ÿคข", "ู…ู‚ุฑู")
        text = text.replace("๐Ÿคฎ", "ู…ู‚ุฑู")
        text = text.replace("๐Ÿ• ", "โŒš")
        text = text.replace("๐Ÿคฌ", "๐Ÿ˜ ")
        text = text.replace("๐Ÿคง", "๐Ÿ˜ท")
        text = text.replace("๐Ÿฅณ", "๐ŸŽ‰")
        text = text.replace("๐Ÿฅต", "๐Ÿ”ฅ")
        text = text.replace("๐Ÿฅด", "โ˜น")
        text = text.replace("๐Ÿคซ", "๐Ÿค")
        text = text.replace("๐Ÿคฅ", "ูƒุฐุงุจ")
        # Literal zero-width-joiner escapes, with and without the backslash.
        text = text.replace("\\u200d", " ")
        text = text.replace("u200d", " ")
        text = text.replace("\\u200c", " ")
        text = text.replace("u200c", " ")
        text = text.replace('"', "'")
        text = text.replace("\\xa0", "")
        text = text.replace("\\u2066", " ")
        # Drop any remaining literal \u.../\U... escape whose backslash is not
        # directly preceded by a word character.
        text = re.sub(r"\B\\[Uu]\w+", "", text)
        text = super().preprocess(text)

        return " ".join(text.split())


"""CNNMarbertArabicPreprocessor"""
# ASAD-CNN_MARBERT
class CNNMarbertArabicPreprocessor(ArabertPreprocessor):
    """ArabertPreprocessor variant with an extra pass for "UBC-NLP" models.

    :meth:`preprocess` runs :meth:`ubc_prep` when the model name contains
    "UBC-NLP" and otherwise defers to the parent preprocessor.
    """

    def __init__(
        self,
        model_name,
        keep_emojis=False,
        remove_html_markup=True,
        replace_urls_emails_mentions=True,
        remove_elongations=True,
    ):
        """Configure the parent preprocessor and remember the model name.

        For model names containing "UBC-NLP" or "CAMeL-Lab", ``keep_emojis``
        is forced to ``True`` and ``remove_elongations`` to ``False``.
        """
        if "UBC-NLP" in model_name or "CAMeL-Lab" in model_name:
            keep_emojis = True
            remove_elongations = False
        # NOTE(review): arguments are passed positionally, so the fifth one
        # lands on whatever the parent declares in its fifth slot. Verify that
        # slot really is the elongation switch; left unchanged here so the
        # preprocessing output stays identical.
        super().__init__(
            model_name,
            keep_emojis,
            remove_html_markup,
            replace_urls_emails_mentions,
            remove_elongations,
        )
        # Remember the name exactly as passed in; preprocess() branches on it.
        self.true_model_name = model_name

    def preprocess(self, text):
        """Return the preprocessed form of ``text``.

        "UBC-NLP" models get :meth:`ubc_prep`. Any other model falls back to
        the parent implementation instead of silently returning ``None``.
        """
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        return super().preprocess(text)

    def ubc_prep(self, text):
        """Normalise ``text`` for "UBC-NLP" models around the parent pass.

        Before the parent ``preprocess``: whitespace flattening,
        tashkeel/tatweel stripping, URL/mention placeholders and
        hashtag/underscore splitting. After it: zero-width joiner removal,
        double-quote normalisation and collapsing of any character repeated
        three or more times down to two.
        """
        text = re.sub(r"\s", " ", text)
        # The two-character sequence backslash-n, not a real newline.
        text = text.replace("\\n", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # Replace every URL variant with a placeholder, then collapse runs of
        # placeholders into one.
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        text = re.sub(r"(URL\s*)+", " URL ", text)
        # Same treatment for user mentions.
        text = re.sub(user_mention_regex, " USER ", text)
        text = re.sub(r"(USER\s*)+", " USER ", text)
        # Keep hashtag words but mark the '#' itself; split snake_case tags.
        text = text.replace("#", " HASH ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        text = super().preprocess(text)
        # Real zero-width (non-)joiner characters, plus the bare "u200d" /
        # "u200c" text left behind by partially stripped escapes.
        text = text.replace("\u200d", " ")
        text = text.replace("u200d", " ")
        text = text.replace("\u200c", " ")
        text = text.replace("u200c", " ")
        text = text.replace('"', "'")
        # Collapse runs of 3+ identical characters down to exactly two.
        text = multiple_char_pattern.sub(r"\1\1", text)
        return " ".join(text.split())


"""Trial5ArabicPreprocessor"""


class Trial5ArabicPreprocessor(ArabertPreprocessor):
    """ArabertPreprocessor variant with a pre-pass for "UBC-NLP" models.

    :meth:`preprocess` runs :meth:`ubc_prep` when the model name contains
    "UBC-NLP" and otherwise defers to the parent preprocessor.
    """

    def __init__(
        self,
        model_name,
        keep_emojis=False,
        remove_html_markup=True,
        replace_urls_emails_mentions=True,
    ):
        """Configure the parent preprocessor and remember the model name.

        ``keep_emojis`` is forced to ``True`` for model names containing
        "UBC-NLP".
        """
        if "UBC-NLP" in model_name:
            keep_emojis = True
        super().__init__(
            model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
        )
        # Remember the name exactly as passed in; preprocess() branches on it.
        self.true_model_name = model_name

    def preprocess(self, text):
        """Return the preprocessed form of ``text``.

        "UBC-NLP" models get :meth:`ubc_prep`. Any other model falls back to
        the parent implementation instead of silently returning ``None``.
        """
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        return super().preprocess(text)

    def ubc_prep(self, text):
        """Clean ``text`` for "UBC-NLP" models, then run the parent pass.

        Flattens whitespace, strips tashkeel/tatweel, substitutes URL and
        mention placeholders, and splits hashtags/underscores before handing
        the text to the parent ``preprocess``.
        """
        text = re.sub(r"\s", " ", text)
        # The two-character sequence backslash-n, not a real newline.
        text = text.replace("\\n", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # Replace every URL variant with a placeholder.
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        # Replace user mentions with a placeholder.
        text = re.sub(user_mention_regex, " USER ", text)
        # Keep hashtag words but mark the '#' itself; split snake_case tags.
        text = text.replace("#", " HASH TAG ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        return super().preprocess(text)


"""SarcasmArabicPreprocessor"""


class SarcasmArabicPreprocessor(ArabertPreprocessor):
    """ArabertPreprocessor variant with a pre-pass for "UBC-NLP" models.

    :meth:`preprocess` runs :meth:`ubc_prep` when the model name contains
    "UBC-NLP" and otherwise defers to the parent preprocessor.
    """

    def __init__(
        self,
        model_name,
        keep_emojis=False,
        remove_html_markup=True,
        replace_urls_emails_mentions=True,
    ):
        """Configure the parent preprocessor and remember the model name.

        ``keep_emojis`` is forced to ``True`` for model names containing
        "UBC-NLP".
        """
        if "UBC-NLP" in model_name:
            keep_emojis = True
        super().__init__(
            model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
        )
        # Remember the name exactly as passed in; preprocess() branches on it.
        self.true_model_name = model_name

    def preprocess(self, text):
        """Return the preprocessed form of ``text`` for the configured model."""
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        return super().preprocess(text)

    def ubc_prep(self, text):
        """Clean ``text`` for "UBC-NLP" models, then run the parent pass.

        Flattens whitespace, strips tashkeel/tatweel, substitutes URL and
        mention placeholders, splits hashtags/underscores and removes double
        quotes before handing the text to the parent ``preprocess``.
        """
        text = re.sub(r"\s", " ", text)
        # The two-character sequence backslash-n, not a real newline.
        text = text.replace("\\n", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # Replace every URL variant with a placeholder.
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        # Replace user mentions with a placeholder.
        text = re.sub(user_mention_regex, " USER ", text)
        # Keep hashtag words but mark the '#' itself; split snake_case tags.
        text = text.replace("#", " HASH TAG ")
        text = text.replace("_", " ")
        text = text.replace('"', " ")
        text = " ".join(text.split())
        return super().preprocess(text)


"""NoAOAArabicPreprocessor"""


class NoAOAArabicPreprocessor(ArabertPreprocessor):
    """ArabertPreprocessor variant that also strips two fixed phrases.

    For model names containing "UBC-NLP", :meth:`preprocess` runs
    :meth:`ubc_prep`, which removes exact and near matches of the phrases in
    ``_STRIPPED_PHRASES`` after the parent pass. Other models defer to the
    parent preprocessor.
    """

    # Phrases removed from the text, first exactly and then fuzzily.
    # NOTE(review): presumably Arabic greeting formulas; the literals look
    # mis-encoded (mojibake) in this copy of the file and are kept exactly as
    # found. Confirm them against the original file encoding.
    _STRIPPED_PHRASES = ("ุงู„ุณู„ุงู… ุนู„ูŠูƒู…", "ูˆุฑุญู…ุฉ ุงู„ู„ู‡ ูˆุจุฑูƒุงุชู‡")

    def __init__(
        self,
        model_name,
        keep_emojis=False,
        remove_html_markup=True,
        replace_urls_emails_mentions=True,
    ):
        """Configure the parent preprocessor and remember the model name.

        ``keep_emojis`` is forced to ``True`` for model names containing
        "UBC-NLP".
        """
        if "UBC-NLP" in model_name:
            keep_emojis = True
        super().__init__(
            model_name, keep_emojis, remove_html_markup, replace_urls_emails_mentions
        )
        # Remember the name exactly as passed in; preprocess() branches on it.
        self.true_model_name = model_name

    def preprocess(self, text):
        """Return the preprocessed form of ``text`` for the configured model."""
        if "UBC-NLP" in self.true_model_name:
            return self.ubc_prep(text)
        return super().preprocess(text)

    def ubc_prep(self, text):
        """Clean ``text`` for "UBC-NLP" models and strip the fixed phrases.

        Runs the shared clean-up (whitespace, tashkeel/tatweel, URL/mention
        placeholders, hashtag/underscore splitting), the parent
        ``preprocess``, then removes every exact occurrence of each phrase
        followed by the first fuzzy match of each phrase.
        """
        text = re.sub(r"\s", " ", text)
        # The two-character sequence backslash-n, not a real newline.
        text = text.replace("\\n", " ")
        text = araby.strip_tashkeel(text)
        text = araby.strip_tatweel(text)
        # Replace every URL variant with a placeholder.
        for reg in url_regexes:
            text = re.sub(reg, " URL ", text)
        # Replace user mentions with a placeholder.
        text = re.sub(user_mention_regex, " USER ", text)
        # Keep hashtag words but mark the '#' itself; split snake_case tags.
        text = text.replace("#", " HASH TAG ")
        text = text.replace("_", " ")
        text = " ".join(text.split())
        text = super().preprocess(text)

        # Exact removals for all phrases come first, then the fuzzy pass, so
        # the fuzzy search never sees a verbatim occurrence.
        for phrase in self._STRIPPED_PHRASES:
            text = text.replace(phrase, " ")
        for phrase in self._STRIPPED_PHRASES:
            near = find_near_matches(phrase, text, max_deletions=3, max_l_dist=3)
            if near:
                # Only the first near match is used; every occurrence of its
                # matched substring is blanked out.
                text = text.replace(near[0].matched, " ")
        return text


class CnnBertForSequenceClassification(BertPreTrainedModel):
    """BERT encoder with a multi-width convolutional classification head.

    The last ``_NUM_STACKED_LAYERS`` encoder hidden states are stacked as
    input channels, convolved with filters of several heights spanning the
    full hidden width, max-pooled over the sequence, concatenated and fed
    through dropout and a linear classifier.
    """

    # How many of the final encoder hidden states become CNN input channels.
    _NUM_STACKED_LAYERS = 4

    def __init__(self, config):
        """Build the encoder, the convolution bank and the classifier."""
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)

        filter_sizes = [1, 2, 3, 4, 5]
        num_filters = 32
        # Each filter covers `height` consecutive tokens across the whole
        # hidden dimension, so the convolution output has width 1.
        self.convs1 = nn.ModuleList(
            [
                nn.Conv2d(
                    self._NUM_STACKED_LAYERS,
                    num_filters,
                    (height, config.hidden_size),
                )
                for height in filter_sizes
            ]
        )
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(len(filter_sizes) * num_filters, config.num_labels)

        self.init_weights()

    @staticmethod
    def _conv_and_pool(conv, stacked):
        """Apply one filter bank with ReLU and max-pool over the sequence.

        ``stacked`` is indexed as (batch, channels, sequence, hidden). When
        the sequence is shorter than the filter height it is zero-padded at
        the end so the convolution is defined; longer inputs are untouched.
        """
        shortfall = conv.kernel_size[0] - stacked.size(2)
        if shortfall > 0:
            stacked = F.pad(stacked, (0, 0, 0, shortfall))
        feature_map = F.relu(conv(stacked)).squeeze(3)
        return F.max_pool1d(feature_map, feature_map.size(2)).squeeze(2)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the encoder and CNN head; optionally compute a loss.

        ``output_hidden_states`` is accepted for interface compatibility, but
        the encoder is always run with hidden states enabled because the CNN
        head consumes them.

        Returns a ``SequenceClassifierOutput`` (with ``hidden_states`` left as
        ``None``) when ``return_dict`` resolves to true. Otherwise returns a
        tuple ``(loss?, logits, *encoder_outputs[2:])``.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            # Always request hidden states: index 2 of the encoder output is
            # only the hidden-state tuple when they are produced.
            output_hidden_states=True,
            return_dict=return_dict,
        )

        hidden_states = outputs.hidden_states if return_dict else outputs[2]
        x = torch.stack(tuple(hidden_states[-self._NUM_STACKED_LAYERS:]), dim=1)
        x = torch.cat([self._conv_and_pool(conv, x) for conv in self.convs1], 1)
        x = self.dropout(x)
        logits = self.classifier(x)

        loss = None
        if labels is not None:
            # Infer the problem type once from label count and dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            # Hidden states are intentionally not returned in dict mode,
            # matching the previous behaviour of this model.
            hidden_states=None,
            attentions=outputs.attentions,
        )


class CNNTextClassificationPipeline:
    """Minimal text-classification pipeline around the CNN-over-BERT model.

    Loads ``CnnBertForSequenceClassification`` and its tokenizer from
    ``model_path`` and turns raw strings into label/score dictionaries.
    """

    def __init__(self, model_path, device, return_all_scores=False):
        """Load model and tokenizer.

        Args:
            model_path: Path or identifier passed to ``from_pretrained``.
            device: Negative for CPU, otherwise the CUDA device index.
            return_all_scores: When true, ``__call__`` returns one entry per
                label instead of only the best one.
        """
        self.model_path = model_path
        self.model = CnnBertForSequenceClassification.from_pretrained(self.model_path)
        # Special handling
        self.device = torch.device("cpu" if device < 0 else f"cuda:{device}")
        if self.device.type == "cuda":
            self.model = self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.return_all_scores = return_all_scores

    @contextmanager
    def device_placement(self):
        """
        Context Manager allowing tensor allocation on the user-specified device in framework agnostic way.
        Returns:
            Context manager
        Examples::
            # Explicitly ask for tensor allocation on CUDA device :0
            pipe = pipeline(..., device=0)
            with pipe.device_placement():
                # Every framework specific tensor allocation will be done on the request device
                output = pipe(...)
        """

        if self.device.type == "cuda":
            torch.cuda.set_device(self.device)

        yield

    def ensure_tensor_on_device(self, **inputs):
        """
        Ensure PyTorch tensors are on the specified device.
        Args:
            inputs (keyword arguments that should be :obj:`torch.Tensor`): The tensors to place on :obj:`self.device`.
        Return:
            :obj:`Dict[str, torch.Tensor]`: The same as :obj:`inputs` but on the proper device.
        """
        return {
            name: tensor.to(self.device) if isinstance(tensor, torch.Tensor) else tensor
            for name, tensor in inputs.items()
        }

    def __call__(self, text):
        """
        Classify the text(s) given as inputs.
        Args:
            text (:obj:`str` or :obj:`List[str]`):
                One or several texts to classify. A single string is treated
                as a batch of one.
        Return:
            A list or a list of list of :obj:`dict`: Each result comes as list of dictionaries with the following keys:
            - **label** (:obj:`str`) -- The label predicted.
            - **score** (:obj:`float`) -- The corresponding probability.
            If ``self.return_all_scores=True``, one such dictionary is returned per label.
        """
        # A bare string would otherwise be handed to the batch encoder as if
        # it were a sequence of texts.
        if isinstance(text, str):
            text = [text]

        inputs = self.tokenizer.batch_encode_plus(
            text,
            add_special_tokens=True,
            max_length=64,
            padding=True,
            truncation="longest_first",
            return_tensors="pt",
        )

        with torch.no_grad():
            inputs = self.ensure_tensor_on_device(**inputs)
            predictions = self.model(**inputs)[0].cpu()

        logits = predictions.numpy()
        id2label = self.model.config.id2label

        if self.model.config.num_labels == 1:
            scores = 1.0 / (1.0 + np.exp(-logits))
        else:
            # Subtract the row maximum before exponentiating so large logits
            # cannot overflow to inf (which would yield NaN probabilities).
            shifted = np.exp(logits - logits.max(-1, keepdims=True))
            scores = shifted / shifted.sum(-1, keepdims=True)

        if self.return_all_scores:
            return [
                [
                    {"label": id2label[i], "score": score.item()}
                    for i, score in enumerate(item)
                ]
                for item in scores
            ]
        # Labels come from the model config; int() turns the numpy index into
        # a plain dictionary key.
        return [
            {"label": id2label[int(item.argmax())], "score": item.max().item()}
            for item in scores
        ]