import pytorch_lightning as pl
import torch
import numpy as np
import datasets
from transformers import MaxLengthCriteria, StoppingCriteriaList
from transformers.optimization import AdamW
import itertools
# from utils import count_stats, f1_metric, pairwise_meteor
from torchmetrics.text.rouge import ROUGEScore
import torch.nn.functional as F
import torchmetrics
from torchmetrics.classification import F1Score

class SequenceClassificationModule(pl.LightningModule):
  """Lightning module that trains and evaluates a sequence classifier.

  Every batch is an ``(input_ids, attention_mask, labels)`` triple. The wrapped
  ``model`` is called with ``attention_mask=`` (and ``labels=`` during
  training/validation). It must return an object exposing ``.logits`` and,
  when labels are given, ``.loss``. It must also expose ``num_labels``
  (presumably a Hugging Face ``*ForSequenceClassification`` model).

  Two evaluation modes are supported:

  * ``use_question_stance_approach=False``: every example is scored directly
    with accuracy and F1 inside ``validation_step`` / ``test_step``.
  * ``use_question_stance_approach=True``: every example is one question about
    a claim. At epoch end the per-question predictions are grouped by claim,
    merged into one claim-level verdict, and only then scored (see
    ``handle_end_of_epoch_scoring``).
  """

  # Names used for per-class F1 logging, indexed by label id. The
  # question-stance aggregation in ``_aggregate_claim_verdict`` relies on the
  # same id order (0=supported, 1=refuted, 2=nei, 3=conflicting).
  CLASS_NAMES = ("supported", "refuted", "nei", "conflicting")

  def __init__(self, tokenizer, model, use_question_stance_approach=True, learning_rate=1e-3):
    """Store the tokenizer/model and create the accuracy and F1 metrics.

    Args:
      tokenizer: Used to decode ``input_ids`` back to text when grouping
        questions by claim in question-stance mode.
      model: The classifier to wrap; ``model.num_labels`` sets the number of
        classes for every metric.
      use_question_stance_approach: If True, score claim-level verdicts
        aggregated at epoch end instead of per-example predictions.
      learning_rate: Learning rate passed to the optimizer.
    """
    super().__init__()
    self.tokenizer = tokenizer
    self.model = model
    self.learning_rate = learning_rate

    num_labels = model.num_labels
    # NOTE: train_acc / train_f1 are created but not updated by training_step.
    self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_labels)
    self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_labels)
    self.test_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_labels)

    # Train F1 is macro-averaged; val/test keep per-class scores
    # (average=None) so each class can be logged separately at epoch end.
    self.train_f1 = F1Score(task="multiclass", num_classes=num_labels, average="macro")
    self.val_f1 = F1Score(task="multiclass", num_classes=num_labels, average=None)
    self.test_f1 = F1Score(task="multiclass", num_classes=num_labels, average=None)

    self.use_question_stance_approach = use_question_stance_approach

  def forward(self, input_ids, **kwargs):
    """Run the wrapped model; extra keyword arguments are forwarded unchanged."""
    return self.model(input_ids, **kwargs)

  def configure_optimizers(self):
    """Return an AdamW optimizer over all parameters at ``self.learning_rate``."""
    # NOTE(review): ``transformers.optimization.AdamW`` is deprecated in newer
    # transformers releases. ``torch.optim.AdamW`` has different defaults
    # (e.g. weight_decay), so switching would change training behaviour.
    return AdamW(self.parameters(), lr=self.learning_rate)

  def training_step(self, batch, batch_idx):
    """Compute and log the training loss for one batch.

    The loss is the one the wrapped model returns when ``labels`` are given.
    """
    x, x_mask, y = batch
    outputs = self(x, attention_mask=x_mask, labels=y)
    loss = outputs.loss
    self.log("train_loss", loss)
    return {'loss': loss}

  def validation_step(self, batch, batch_idx):
    """Run one validation batch and return what epoch-end scoring needs.

    In per-example mode the accuracy/F1 metrics are updated here. In
    question-stance mode scoring is deferred to ``validation_epoch_end``.
    The validation loss is logged in both modes.
    """
    x, x_mask, y = batch
    outputs = self(x, attention_mask=x_mask, labels=y)
    loss = outputs.loss
    preds = torch.argmax(outputs.logits, dim=1)

    if not self.use_question_stance_approach:
      self.val_acc(preds, y)
      self.log('val_acc_step', self.val_acc)
      self.val_f1(preds, y)

    # The loss is per-example and meaningful regardless of how accuracy/F1
    # are aggregated, so it is always logged.
    self.log("val_loss", loss)

    return {'val_loss': loss, "src": x, "pred": preds, "target": y}

  def validation_epoch_end(self, outs):
    """Score (question-stance mode) and log validation metrics for the epoch."""
    if self.use_question_stance_approach:
      self.handle_end_of_epoch_scoring(outs, self.val_acc, self.val_f1)
    self._log_epoch_scores("val", self.val_acc, self.val_f1)

  def test_step(self, batch, batch_idx):
    """Run one test batch; no labels are passed to the model, so no loss."""
    x, x_mask, y = batch
    outputs = self(x, attention_mask=x_mask)
    preds = torch.argmax(outputs.logits, dim=1)

    if not self.use_question_stance_approach:
      self.test_acc(preds, y)
      self.log('test_acc_step', self.test_acc)
      self.test_f1(preds, y)

    return {"src": x, "pred": preds, "target": y}

  def test_epoch_end(self, outs):
    """Score (question-stance mode) and log test metrics for the epoch."""
    if self.use_question_stance_approach:
      self.handle_end_of_epoch_scoring(outs, self.test_acc, self.test_f1)
    self._log_epoch_scores("test", self.test_acc, self.test_f1)

  def _log_epoch_scores(self, prefix, acc_metric, f1_metric):
    """Log epoch accuracy, mean F1 and per-class F1 under ``prefix``.

    The F1 metric is computed and reset here. ``f1_metric`` is expected to be
    built with ``average=None``, so its result holds one score per class.
    """
    self.log(prefix + "_acc_epoch", acc_metric)

    f1 = f1_metric.compute()
    f1_metric.reset()
    self.log(prefix + "_f1_epoch", torch.mean(f1))

    # zip() stops at the shorter sequence, so a model with fewer than
    # len(CLASS_NAMES) labels no longer raises IndexError here.
    for c_name, score in zip(self.CLASS_NAMES, f1):
      self.log(prefix + "_f1_" + c_name, score)

  def handle_end_of_epoch_scoring(self, outputs, acc_scorer, f1_scorer):
    """Aggregate per-question predictions into claim verdicts and score them.

    Each step output's ``src`` token ids are decoded, and the text before the
    first ``"[ question ]"`` marker is used as the claim key. All questions
    about the same claim are therefore grouped together. The target of the
    first example seen for a claim is used as that claim's gold label.

    Args:
      outputs: Dicts returned by ``validation_step`` / ``test_step`` with keys
        ``src`` (token ids), ``pred`` and ``target`` (label-id tensors).
      acc_scorer: Metric updated with one (verdict, gold) pair per claim.
      f1_scorer: Metric updated with one (verdict, gold) pair per claim.

    Raises:
      ValueError: if a claim's predictions contain none of the ids 0-3.
    """
    gold_labels = {}
    question_support = {}
    for out in outputs:
      tokens = self.tokenizer.batch_decode(
        out['src'],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
      )
      # Convert once to Python ints instead of comparing device tensors
      # element by element below.
      preds = out['pred'].tolist()
      tgts = out['target'].tolist()

      for src, pred, tgt in zip(tokens, preds, tgts):
        claim_id = src.split("[ question ]")[0]
        if claim_id not in gold_labels:
          gold_labels[claim_id] = tgt
          question_support[claim_id] = []
        question_support[claim_id].append(pred)

    if not gold_labels:
      return

    claim_ids = list(gold_labels)
    # Build tensors on the module's own device (where the metrics live)
    # instead of a hard-coded "cuda:0", so CPU and multi-GPU runs work.
    answers = torch.tensor(
      [self._aggregate_claim_verdict(question_support[k]) for k in claim_ids],
      device=self.device,
    )
    golds = torch.tensor([gold_labels[k] for k in claim_ids], device=self.device)

    # A single batched update accumulates the same statistics as one update
    # per claim.
    acc_scorer(answers, golds)
    f1_scorer(answers, golds)

  @staticmethod
  def _aggregate_claim_verdict(support):
    """Merge the per-question label ids of one claim into a single verdict.

    * any question predicted 2 or 3 -> 2 (nei)
    * both 0 and 1 present -> 3 (conflicting)
    * only 0 -> 0 (supported); only 1 -> 1 (refuted)

    Raises:
      ValueError: if ``support`` contains none of the label ids 0-3.
    """
    labels = set(support)
    # TODO very ugly hack -- we cant have different numbers of labels for
    # train and test, so a per-question 3 is treated like unanswerable (2).
    if labels & {2, 3}:
      return 2
    has_true = 0 in labels
    has_false = 1 in labels
    if has_true and has_false:
      return 3
    if has_true:
      return 0
    if has_false:
      return 1
    # Fail loudly rather than assign an arbitrary verdict when the model
    # emits label ids this aggregation does not know about.
    raise ValueError(
      "cannot aggregate question predictions %r into a claim verdict" % (support,)
    )