from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
from transformers import BertConfig, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
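
# Hydra: one shared BERT encoder feeding several classification heads. Each
# entry in HydraConfig.label_groups defines one head; the heads share a single
# linear layer whose output is split into per-group logit slices in forward().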


class HydraConfig(BertConfig):
    model_type = "hydra"

    def __init__(self, label_groups=None, **kwargs):
        # Store label_groups on the config so it is serialized and reloaded
        # together with the usual BERT hyperparameters.
        self.label_groups = label_groups if label_groups is not None else []
        super().__init__(**kwargs)

    @property
    def num_labels(self) -> int:
        # Total number of logits across all label groups.
        return sum(len(group) for group in self.label_groups)

    @num_labels.setter
    def num_labels(self, value):
        # Always derived from label_groups; ignore the default that
        # PretrainedConfig.__init__ assigns.
        pass

    def bert_config(self):
        # A plain BertConfig built from this config's fields.
        return BertConfig(**self.__dict__)
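
# The exact shape of label_groups is not pinned down above; from the
# index -> name inversion in Hydra.forward it is assumed to be one dict per
# head, mapping each label name to its index within that group, e.g.
#
#   label_groups = [
#       {"negative": 0, "positive": 1},
#       {"question": 0, "statement": 1, "command": 2},
#   ]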


@dataclass
class HydraSequenceClassifierOutput(SequenceClassifierOutput):
    # Per label group: a list of (label, probability) pairs sorted by descending
    # probability. Only populated for single-example batches (see Hydra.forward).
    classifications: Optional[List[List[Tuple[str, float]]]] = None


class Hydra(BertModel):
    config_class = HydraConfig

    def __init__(self, config: HydraConfig):
        super().__init__(config)
        self.pre_classifier = nn.Linear(config.hidden_size, config.hidden_size)
        # A single shared linear layer; its output is sliced into one block of
        # logits per label group in forward().
        self.classifiers = nn.Linear(
            config.hidden_size, sum(len(group) for group in config.label_groups))
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # Freeze the embedding layer; only the encoder and the heads are trained.
        self.embeddings.requires_grad_(False)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        bert_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_state = bert_output[0]  # (bs, seq_len, dim)
        pooled_output = hidden_state[:, 0]  # (bs, dim), the [CLS] token
        pooled_output = self.pre_classifier(pooled_output)  # (bs, dim)
        pooled_output = nn.functional.relu(pooled_output)  # (bs, dim)
        pooled_output = self.dropout(pooled_output)  # (bs, dim)
        logits = self.classifiers(pooled_output)  # (bs, num_labels)

        loss = None
        if labels is not None:
            # A single cross-entropy over the concatenated logits; labels are
            # expected to index into the full concatenated label space.
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + bert_output[1:]
            return ((loss,) + output) if loss is not None else output

        # Human-readable output: for each label group, (label, probability)
        # pairs sorted by descending probability. Only computed for
        # single-example batches.
        classifications = []
        if logits.shape[0] == 1:
            offset = 0
            for group in self.config.label_groups:
                inverted = {index: label for label, index in group.items()}
                probs = torch.softmax(logits[:, offset:offset + len(group)], dim=-1)
                classification = [(inverted[i], val.item()) for i, val in enumerate(probs[0])]
                classification.sort(key=lambda x: x[1], reverse=True)
                classifications.append(classification)
                offset += len(group)

        return HydraSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=bert_output.hidden_states,
            attentions=bert_output.attentions,
            classifications=classifications,
        )

    def to(self, *args, **kwargs):
        # nn.Module.to already moves registered submodules (pre_classifier,
        # classifiers, dropout); accept the full signature so device/dtype
        # keyword forms keep working.
        return super().to(*args, **kwargs)
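

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the model
    # definition). The label_groups below and the tiny hyperparameters are
    # assumptions chosen so the script runs quickly on CPU.
    config = HydraConfig(
        label_groups=[
            {"negative": 0, "positive": 1},
            {"question": 0, "statement": 1, "command": 2},
        ],
        vocab_size=1000,
        hidden_size=64,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=128,
        max_position_embeddings=64,
    )
    model = Hydra(config)
    model.eval()

    # One example of 16 random token ids; batch size 1 so that
    # output.classifications is populated.
    input_ids = torch.randint(0, config.vocab_size, (1, 16))
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        output = model(input_ids=input_ids, attention_mask=attention_mask)

    # Each entry is one head's (label, probability) pairs, best guess first.
    for group_result in output.classifications:
        print(group_result)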