# -*- coding: utf-8 -*-
"""
Created on Thu May 19 13:22:32 2022

@author: UTKARSH
"""

import glob
import os

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from src.clean import clean_license_text

from tqdm.auto import tqdm

from transformers import (
    AdamW,
    T5ForConditionalGeneration,
    T5TokenizerFast as T5Tokenizer,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)

from src.read_data import read_license_summary_data


# Local checkpoint location: fine-tuned weights are saved to and loaded from
# MODEL_PATH + MODEL_FILENAME.
MODEL_PATH = "models/"
MODEL_FILENAME = "t5-base.model"

# Hugging Face identifier of the base model / tokenizer.
MODEL_NAME = "t5-base"
# Module-level tokenizer slot, initialized to None.
# NOTE(review): nothing in this module rebinds it via a `global` statement;
# confirm whether it is still needed.
TOKENIZER = None

# Token lengths used as `max_length` when tokenizing license texts and summaries.
TEXT_MAX_TOKEN_LEN = 512
SUMMARY_MAX_TOKEN_LEN = 128

# Training hyperparameters.
N_EPOCHS = 1
BATCH_SIZE = 1

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LicenseSummaryDataset(Dataset):
    """
    Map-style dataset of (license text, summary) pairs tokenized for T5.

    Each item is a dict holding the raw text/summary strings plus fixed-length
    token-id and attention-mask tensors for the text (encoder input) and the
    summary (labels).

    Parameters
    ----------
    data : pd.DataFrame
        Rows with a ``"text"`` and a ``"summary"`` column.
    tokenizer : T5Tokenizer
        Tokenizer called with ``return_tensors="pt"``; must return
        ``"input_ids"`` and ``"attention_mask"`` tensors.
    text_max_token_len : int
        Length every text encoding is padded/truncated to.
    summary_max_token_len : int
        Length every summary encoding is padded/truncated to.
    """

    def __init__(
        self,
        data: pd.DataFrame,
        tokenizer: "T5Tokenizer",
        text_max_token_len: int = 512,
        summary_max_token_len: int = 128,
    ):
        self.tokenizer = tokenizer
        self.data = data
        self.text_max_token_len = text_max_token_len
        self.summary_max_token_len = summary_max_token_len

    def __len__(self):
        return len(self.data)

    def _encode(self, text, max_length):
        """Tokenize ``text`` padded/truncated to exactly ``max_length`` tokens."""
        return self.tokenizer(
            text,
            max_length=max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

    def __getitem__(self, index: int):
        data_row = self.data.iloc[index]
        text = data_row["text"]
        summary = data_row["summary"]

        text_encoding = self._encode(text, self.text_max_token_len)
        summary_encoding = self._encode(summary, self.summary_max_token_len)

        # Replace padding positions in the labels with -100, the label value
        # Hugging Face seq2seq losses ignore. Use the tokenizer's own pad id
        # rather than a hard-coded 0; fall back to 0 (T5's pad id) if the
        # tokenizer does not define one.
        pad_token_id = getattr(self.tokenizer, "pad_token_id", None)
        if pad_token_id is None:
            pad_token_id = 0
        labels = summary_encoding["input_ids"].clone()
        labels[labels == pad_token_id] = -100

        return dict(
            text=text,
            summary=summary,
            text_input_ids=text_encoding["input_ids"].flatten(),
            text_attention_mask=text_encoding["attention_mask"].flatten(),
            labels=labels.flatten(),
            labels_attention_mask=summary_encoding["attention_mask"].flatten(),
        )


def prepare_dataloaders(test_size=0.1, random_state=None):
    """
    Helper method to load data and create batched DataLoaders.

    Loads the license/summary data, splits it into train and validation sets,
    tokenizes both with the base T5 tokenizer (``MODEL_NAME``) and wraps them
    in DataLoaders of ``BATCH_SIZE``.

    Parameters
    ----------
    test_size : float or int, optional
        Passed to ``train_test_split``: fraction (or absolute count) of rows
        held out for validation. Defaults to 0.1.
    random_state : int or None, optional
        Seed passed to ``train_test_split`` to make the split reproducible.
        ``None`` (default) leaves the split unseeded.

    Returns
    -------
    train_dataloader : DataLoader
        Train DataLoader (shuffled each epoch).
    dev_dataloader : DataLoader
        Validation DataLoader (fixed order).
    """
    license_summary_data = pd.DataFrame(read_license_summary_data())

    train_df, dev_df = train_test_split(
        license_summary_data, test_size=test_size, random_state=random_state
    )

    # Local tokenizer; the module-level TOKENIZER is left untouched.
    tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)

    train_dataset = LicenseSummaryDataset(
        train_df,
        tokenizer,
        TEXT_MAX_TOKEN_LEN,
        SUMMARY_MAX_TOKEN_LEN,
    )

    dev_dataset = LicenseSummaryDataset(
        dev_df,
        tokenizer,
        TEXT_MAX_TOKEN_LEN,
        SUMMARY_MAX_TOKEN_LEN,
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=0,
    )

    # Shuffling gives no benefit when evaluating; keep the validation order deterministic.
    dev_dataloader = DataLoader(
        dev_dataset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=0,
    )

    return train_dataloader, dev_dataloader


def train(epoch, model, dataloader, optimizer, batch_size):
    """
    Trains the given model for a single epoch over the given dataloader.

    Parameters
    ----------
    epoch : int
        The epoch number; used only in the progress bar and the loss log line.
    model : Summarizer / torch.nn.Module
        Seq2seq model called with ``input_ids``, ``attention_mask``,
        ``decoder_attention_mask`` and ``labels``; its output must expose ``.loss``.
    dataloader : torch.utils.data.DataLoader
        Yields dict batches with ``text_input_ids``, ``text_attention_mask``,
        ``labels`` and ``labels_attention_mask`` entries.
    optimizer : torch.optim.Optimizer
        The optimizer stepped after every batch (e.g. AdamW).
    batch_size : int
        Unused; kept so existing callers do not break.

    Returns
    -------
    float
        Mean training loss over the epoch, or NaN if the dataloader yielded
        no batches.
    """
    model.train()
    total_train_loss = 0.0
    num_batches = 0

    # Wrap the dataloader itself (not an enumerate object) so tqdm knows the total.
    for batch in tqdm(dataloader, desc=f"Epoch {epoch}"):
        model.zero_grad()
        input_ids = batch["text_input_ids"].to(device, dtype=torch.long)
        attention_mask = batch["text_attention_mask"].to(device, dtype=torch.long)
        labels = batch["labels"].to(device, dtype=torch.long)
        labels_attention_mask = batch["labels_attention_mask"].to(device, dtype=torch.long)

        model_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_attention_mask=labels_attention_mask,
            labels=labels,
        )
        loss = model_output.loss

        total_train_loss += loss.item()
        num_batches += 1
        loss.backward()
        optimizer.step()

    # Guard against an empty dataloader instead of raising ZeroDivisionError.
    avg_train_loss = total_train_loss / num_batches if num_batches else float("nan")

    print(f"Epoch {epoch}: Training loss: {avg_train_loss}")
    return avg_train_loss


def train_and_save_model(train_dataloader, PATH):
    """
    Fine-tunes the base T5 model on the given DataLoader and saves its weights.

    Parameters
    ----------
    train_dataloader : DataLoader
        Batched training DataLoader.
    PATH : str
        File path where the trained model's ``state_dict`` is saved.

    Returns
    -------
    model : Summarizer / torch.nn.Module
        Trained model, switched to eval mode so it is ready for generation.
    """
    model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True).to(device)

    # torch.optim.AdamW replaces the deprecated transformers.AdamW. eps and
    # weight_decay are set to transformers.AdamW's defaults (1e-6, 0.0) so the
    # optimization matches the previous setup.
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5, eps=1e-6, weight_decay=0.0)

    for epoch in range(1, N_EPOCHS + 1):
        train(epoch, model, train_dataloader, optimizer, BATCH_SIZE)

    torch.save(model.state_dict(), PATH)

    print("Model Saved!")

    # Training put the model in train mode (dropout active); disable that
    # before the model is returned for inference.
    model.eval()

    return model


def summarize_text_with_model(text, model, tokenizer):
    """
    Generate a summary of a license text with an already-loaded model.

    The text is first passed through ``clean_license_text``. The cleaned text
    is tokenized (padded/truncated to ``TEXT_MAX_TOKEN_LEN``) and given to
    ``model.generate``, which runs 2-beam search capped at
    ``SUMMARY_MAX_TOKEN_LEN`` tokens.

    Parameters
    ----------
    text : str
        The License text to be summarized.
    model : Summarizer / torch.nn.Module
        Trained model exposing ``generate``.
    tokenizer : Tokenizer
        Tokenizer matching ``model``; used both to encode and to decode.

    Returns
    -------
    summary : str
        Generated sequence(s) decoded without special tokens and concatenated.
    definitions : str
        Definitions extracted from the License text by ``clean_license_text``.
    """
    cleaned_text, definitions, _ = clean_license_text(text)

    encoded = tokenizer(
        cleaned_text,
        max_length=TEXT_MAX_TOKEN_LEN,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device, dtype=torch.long)
    attention_mask = encoded["attention_mask"].to(device, dtype=torch.long)

    generation_settings = dict(
        max_length=SUMMARY_MAX_TOKEN_LEN,
        num_beams=2,
        repetition_penalty=2.5,
        length_penalty=1.0,
        early_stopping=True,
    )
    generated_ids = model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        **generation_settings,
    )

    decoded_summaries = tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return "".join(decoded_summaries), definitions


def summarize(text, load_from_huggingface=True):
    """
    Summarizes the given License text.

    Parameters
    ----------
    text : str
        License text. It is cleaned with ``clean_license_text`` inside
        ``summarize_text_with_model``, so raw file contents are accepted.
    load_from_huggingface : boolean
        Toggles whether or not to load the model from huggingface. If set to
        False, this loads the locally saved model at
        ``MODEL_PATH + MODEL_FILENAME``, training and saving one first if no
        such file exists.

    Returns
    -------
    summary : str
        Summary of the License text.
    definitions : str
        Definitions extracted from the License text.
    """
    if load_from_huggingface:
        print("Loading Model from HuggingFace...")
        CUSTOM_MODEL_NAME = "utkarshsaboo45/ClearlyDefinedLicenseSummarizer"
        model = AutoModelForSeq2SeqLM.from_pretrained(CUSTOM_MODEL_NAME).to(device)
        tokenizer = AutoTokenizer.from_pretrained(CUSTOM_MODEL_NAME)
    else:
        model_file = os.path.join(MODEL_PATH, MODEL_FILENAME)
        if os.path.exists(model_file):
            print("Loading Model...")
            model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME, return_dict=True).to(device)
            # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
            model.load_state_dict(torch.load(model_file, map_location=device))
        else:
            print("Training model...")
            os.makedirs(MODEL_PATH, exist_ok=True)
            train_dataloader, _ = prepare_dataloaders()
            model = train_and_save_model(train_dataloader, model_file)
        # Both local branches use the base T5 tokenizer. It is loaded here so
        # that the training branch also has a tokenizer bound.
        tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
        # Ensure dropout is off for generation; a freshly trained model may
        # still be in train mode.
        model.eval()

    summary, definitions = summarize_text_with_model(text, model, tokenizer)

    return summary, definitions


def summarize_license_files(path):
    """
    Summarize License files and save each summary as a text file.

    Every file matching ``path + "*.txt"`` is summarized with ``summarize``.
    The summary is written next to its source as ``<name>__summary.txt``,
    e.g. ``licenses/MIT.txt`` -> ``licenses/MIT__summary.txt``.

    Parameters
    ----------
    path : str
        Glob prefix for the License files, concatenated directly with
        ``"*.txt"``. This is typically a directory path ending with a
        separator (e.g. ``"licenses/"``).

    Notes
    -----
    Files already ending in ``__summary.txt`` are skipped, so re-running does
    not summarize previously generated summaries.
    """
    summary_suffix = "__summary.txt"

    for license_path in glob.glob(path + "*.txt"):
        if license_path.endswith(summary_suffix):
            continue
        with open(license_path, "r", encoding="utf-8") as f:
            summary, _ = summarize(f.read())
        # splitext strips only the final extension. str.replace would also
        # remove ".txt" appearing elsewhere in the path.
        output_path = os.path.splitext(license_path)[0] + summary_suffix
        with open(output_path, "w", encoding="utf-8") as f:
            f.write(summary)