import streamlit as st
import pandas as pd
from transformers import pipeline
from stqdm import stqdm
from simplet5 import SimpleT5
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers import BertTokenizer, TFBertForSequenceClassification
from datetime import datetime
import logging
from transformers import TextClassificationPipeline
import gc
from datasets import load_dataset
from utils.openllmapi.api import ChatBot
from utils.openllmapi.exceptions import *
import time
from typing import List, Tuple
from collections import OrderedDict

# Extra options forwarded to every BERT TextClassificationPipeline built below.
tokenizer_kwargs = {
    "max_length": 128,
    "truncation": True,
    "padding": True,
}
# Seconds to pause before each chat-bot request (presumably to avoid
# hammering the remote chat service — confirm the required rate).
SLEEP = 2


def cleanMemory(obj: TextClassificationPipeline):
    """Release this function's reference to *obj* and run ``gc.collect()``.

    Only the local name is dropped; any reference the caller still holds
    keeps the object alive, so in practice this is a full garbage
    collection pass.
    """
    obj = None  # drop the local reference before collecting
    gc.collect()


@st.cache_data
def getAllCats():
    """Return the distinct labels of the ``ashhadahsan/amazon_theme`` train split.

    Labels come from the second column (``iloc[:, 1]``) of the ``train``
    split; the ``"Unknown"`` label is excluded. Order follows set
    iteration and is not stable between runs.
    """
    train_frame = load_dataset("ashhadahsan/amazon_theme")["train"].to_pandas()
    distinct = set(train_frame.iloc[:, 1].values.tolist())
    return [label for label in distinct if label != "Unknown"]


@st.cache_data
def getAllSubCats():
    """Return the distinct labels used as sub-theme candidates.

    Reads the second column (``iloc[:, 1]``) of the ``train`` split of
    ``ashhadahsan/amazon_theme`` and drops ``"Unknown"``.

    NOTE(review): this reads the same dataset and column as getAllCats, so
    it returns the theme labels. Confirm whether a sub-theme column or
    dataset was intended.
    """
    train_frame = load_dataset("ashhadahsan/amazon_theme")["train"].to_pandas()
    distinct = set(train_frame.iloc[:, 1].values.tolist())
    return [label for label in distinct if label != "Unknown"]


def assignHF(bot, what: str, to: str, old: List):
    """Ask the chat bot to suggest a one-line label for a review summary.

    Args:
        bot: Chat client exposing ``chat(prompt) -> str``. ``None`` (for
            example when the client failed to initialise) is treated as
            unavailable.
        what: Kind of label requested; callers pass "theme" or "subtheme".
        to: The (one-line) review summary to label.
        old: Labels that already exist; they are listed in the prompt.

    Returns:
        If the reply contains a colon, the text between the first and
        second colon; otherwise the whole reply. Either way it is stripped.
        Returns "" when ``bot`` is None or the chat raises ``ChatError``.
    """
    if bot is None:
        return ""
    known = ", ".join(old)
    try:
        message_content = bot.chat(
            f"""'Assign a one-line {what} to this summary of the text of a review
        {to}
        already assigned themes are , {known}
    theme""",
        )
    except ChatError:
        return ""
    # Presumably replies look like "Theme: <label>"; keep only the label part.
    parts = message_content.split(":")
    label = parts[1] if len(parts) > 1 else message_content
    return label.strip()


@st.cache_resource
def loadZeroShotClassification():
    """Create the cached zero-shot classification pipeline (facebook/bart-large-mnli)."""
    task, checkpoint = "zero-shot-classification", "facebook/bart-large-mnli"
    return pipeline(task, model=checkpoint)


def assignZeroShot(zero, to: str, old: List):
    """Return the highest-scoring candidate label from a zero-shot classifier.

    Args:
        zero: Callable invoked as ``zero(text, candidate_labels)`` that
            returns a mapping with parallel ``"labels"`` and ``"scores"``
            sequences (e.g. the pipeline from loadZeroShotClassification).
        to: Text to classify.
        old: Candidate labels.

    Returns:
        The label with the largest score; on ties, the one listed first in
        the classifier output. Returns "" when there are no candidates or
        the classifier returns no labels.
    """
    if not old:
        return ""
    assigned = zero(to, old)
    scores = dict(zip(assigned["labels"], assigned["scores"]))
    # max() keeps the first maximal key, matching a stable descending sort.
    return max(scores, key=scores.get, default="")


# Today's local date as YYYY-MM-DD; used in the download file name.
date = datetime.strftime(datetime.now(), "%Y-%m-%d")


@st.cache_resource
def load_t5() -> Tuple[AutoModelForSeq2SeqLM, AutoTokenizer]:
    """Load the pretrained ``t5-base`` model and its tokenizer.

    Decorated with ``st.cache_resource`` so Streamlit reruns reuse the same
    objects instead of reloading the checkpoint.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    checkpoint = "t5-base"
    model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


@st.cache_resource
def summarizationModel():
    """Create the cached summarization pipeline for the ``my_awesome_sum/`` checkpoint.

    The model id is a relative path, presumably a locally fine-tuned model
    directory — confirm it ships with the app.
    """
    checkpoint_path = "my_awesome_sum/"
    summarizer = pipeline("summarization", model=checkpoint_path)
    return summarizer


@st.cache_data
def convert_df(df: pd.DataFrame) -> bytes:
    """Serialize *df* to UTF-8 encoded CSV bytes, without the index column.

    Cached with ``st.cache_data`` so reruns with the same table reuse the
    bytes. ``st.cache_resource`` is meant for shared global resources such
    as models, while ``st.cache_data`` is the decorator for serializable
    data like this.
    """
    return df.to_csv(index=False).encode("utf-8")


def load_one_line_summarizer(model):
    """Load the ``snrspeaks/t5-one-line-summary`` T5 checkpoint into *model*.

    *model* must expose a ``load_model`` method taking a model type and a
    checkpoint name (callers pass a ``SimpleT5`` instance). Whatever
    ``load_model`` returns is handed back unchanged.
    """
    model_type, checkpoint = "t5", "snrspeaks/t5-one-line-summary"
    return model.load_model(model_type, checkpoint)


@st.cache_resource
def classify_theme() -> TextClassificationPipeline:
    """Build the cached BERT pipeline that predicts a review's theme.

    The tokenizer and TensorFlow model both come from the
    ``ashhadahsan/amazon-theme-bert-base-finetuned`` checkpoint. The
    pipeline keeps only the top label (``top_k=1``) and receives the
    module-level ``tokenizer_kwargs`` as extra options.
    """
    repo_id = "ashhadahsan/amazon-theme-bert-base-finetuned"
    bert_tokenizer = BertTokenizer.from_pretrained(repo_id)
    bert_model = TFBertForSequenceClassification.from_pretrained(repo_id)
    # Named `classifier` so the imported transformers `pipeline` is not shadowed.
    classifier = TextClassificationPipeline(
        model=bert_model,
        tokenizer=bert_tokenizer,
        top_k=1,
        **tokenizer_kwargs,
    )
    return classifier


@st.cache_resource
def classify_sub_theme() -> TextClassificationPipeline:
    """Build the cached BERT pipeline that predicts a review's sub-theme.

    The tokenizer and TensorFlow model both come from the
    ``ashhadahsan/amazon-subtheme-bert-base-finetuned`` checkpoint. The
    pipeline keeps only the top label (``top_k=1``) and receives the
    module-level ``tokenizer_kwargs`` as extra options.
    """
    repo_id = "ashhadahsan/amazon-subtheme-bert-base-finetuned"
    bert_tokenizer = BertTokenizer.from_pretrained(repo_id)
    bert_model = TFBertForSequenceClassification.from_pretrained(repo_id)
    # Named `classifier` so the imported transformers `pipeline` is not shadowed.
    classifier = TextClassificationPipeline(
        model=bert_model,
        tokenizer=bert_tokenizer,
        top_k=1,
        **tokenizer_kwargs,
    )
    return classifier


# Page chrome and the file input; set_page_config must be the first st call.
st.set_page_config(layout="wide", page_title="Amazon Review | Summarizer")
st.title("Amazon Review Summarizer")

uploaded_file = st.file_uploader("Choose a file", type=["xlsx", "xls", "csv"])

# Chat client used to suggest labels for low-confidence rows. If it cannot
# be created, `bot` is set to None instead of being left unbound, so later
# references do not fail with NameError.
try:
    bot = ChatBot(
        cookies={
            "hf-chat": st.secrets["hf-chat"],
            "token": st.secrets["token"],
        }
    )
except ChatBotInitError as e:
    logging.warning("ChatBot initialisation failed: %s", e)
    bot = None

# Which summarizer to run when the "Summarization" step is enabled.
summarizer_option = st.selectbox(
    "Select Summarizer",
    ("Custom trained on the dataset", "t5-base", "t5-one-line-summary"),
)
col1, col2, col3 = st.columns([1, 1, 1])

# One checkbox per optional processing step, laid out side by side.
with col1:
    summary_yes = st.checkbox("Summarization", value=False)

with col2:
    classification = st.checkbox("Classify Category", value=True)

with col3:
    sub_theme = st.checkbox("Sub theme classification", value=True)

# Rows whose classifier score (rounded to 2 decimals) is <= this value get
# extra label suggestions from the chat bot and the zero-shot model.
treshold = st.slider(
    label="Model Confidence value",
    min_value=0.1,
    max_value=0.8,
    step=0.1,
    value=0.6,
    help="If the model has a confidence score below this number , then a new label is assigned (0.6) means 60 percent and so on",
)

ps = st.empty()

# Only the first _MAX_ROWS reviews of an upload are processed.
_MAX_ROWS = 100


@st.cache_resource
def _load_one_line_model():
    """Return a SimpleT5 instance loaded with the one-line-summary checkpoint.

    Cached so the checkpoint is not reloaded on every click of "Process".
    Shared by the one-line summarizer and the low-confidence labelling step.
    """
    model = SimpleT5()
    load_one_line_summarizer(model=model)
    return model


def _read_upload(upload) -> pd.DataFrame:
    """Read an uploaded CSV/XLSX/XLS file and lower-case its column names."""
    extension = upload.name.rsplit(".", 1)[-1].lower()
    if extension == "csv":
        frame = pd.read_csv(upload)
    elif extension == "xlsx":
        frame = pd.read_excel(upload, engine="openpyxl")
    elif extension == "xls":
        # openpyxl cannot read legacy .xls workbooks; let pandas pick the
        # engine (typically requires xlrd to be installed).
        frame = pd.read_excel(upload)
    else:
        raise ValueError(f"Unsupported file type: .{extension}")
    frame.columns = [str(column).lower() for column in frame.columns]
    return frame


def _summary_function(option):
    """Return a callable mapping one review to its summary for *option*."""
    if option == "Custom trained on the dataset":
        model = summarizationModel()
        return lambda review: model(
            f"summarize: {review}",
            max_length=50,
            early_stopping=True,
        )[0]["summary_text"]
    if option == "t5-base":
        model, tokenizer = load_t5()

        def summarize_t5(review):
            tokens_input = tokenizer.encode(
                "summarize: " + review,
                return_tensors="pt",
                max_length=tokenizer.model_max_length,
                truncation=True,
            )
            summary_ids = model.generate(
                tokens_input,
                min_length=80,
                max_length=150,
                length_penalty=20,
                num_beams=2,
            )
            return tokenizer.decode(summary_ids[0], skip_special_tokens=True)

        return summarize_t5
    if option == "t5-one-line-summary":
        model = _load_one_line_model()
        return lambda review: model.predict(review)[0]
    raise ValueError(f"Unknown summarizer option: {option}")


def _summarize(texts, summarize_one, cancel_slot):
    """Summarize each review; rows that fail or are cancelled get "".

    The result always has ``len(texts)`` entries so it can be assigned as a
    DataFrame column next to the texts.
    """
    summaries = []
    for index in stqdm(range(len(texts))):
        if cancel_slot.button("Cancel", key=index):
            break
        try:
            summaries.append(summarize_one(texts[index]))
        except Exception:
            logging.exception("Summarization failed for row %d", index)
            summaries.append("")
    # Pad after a cancel so the column length still matches the texts.
    summaries.extend([""] * (len(texts) - len(summaries)))
    return summaries


def _label_texts(texts, classifier, candidates, what, desc, colour):
    """Classify each review and suggest labels for low-confidence ones.

    Returns three lists aligned with *texts*: the classifier's top label,
    the chat-bot suggestion, and the zero-shot pick among *candidates*.
    The last two hold "" for rows whose rounded score exceeds ``treshold``.
    """
    one_liner = _load_one_line_model()
    zero_shot = loadZeroShotClassification()
    labels, suggested, zero_picks = [], [], []
    for review in stqdm(texts, desc=desc, total=len(texts), colour=colour):
        top = classifier(review)[0][0]  # run the classifier once per review
        labels.append(top["label"])
        if round(top["score"], 2) <= treshold:
            summary = one_liner.predict(review)[0]
            time.sleep(SLEEP)  # pause before each chat-bot request
            # `bot` is the module-level chat client, looked up only when a
            # suggestion is actually needed.
            suggested.append(
                assignHF(bot=bot, what=what, to=summary, old=candidates)
            )
            zero_picks.append(
                assignZeroShot(zero=zero_shot, to=summary, old=candidates)
            )
        else:
            suggested.append("")
            zero_picks.append("")
    return labels, suggested, zero_picks


def _process_reviews(texts, cancel_slot):
    """Build the output table for *texts* from the current widget settings."""
    output = pd.DataFrame({"text": texts})
    if summary_yes:
        output["summary"] = _summarize(
            texts, _summary_function(summarizer_option), cancel_slot
        )
    if classification:
        themePipe = classify_theme()
        labels, suggested, zero_picks = _label_texts(
            texts, themePipe, getAllCats(), "theme", "Assigning Themes ...", "#BF1A1A"
        )
        output["Review Theme"] = labels
        output["Review Theme-issue-new"] = suggested
        # Own column, so the sub-theme zero-shot results cannot overwrite it.
        output["Review Theme-issue-zero"] = zero_picks
        cleanMemory(themePipe)
    if sub_theme:
        subThemePipe = classify_sub_theme()
        labels, suggested, zero_picks = _label_texts(
            texts,
            subThemePipe,
            getAllSubCats(),
            "subtheme",
            "Assigning Subthemes ...",
            "green",
        )
        output["Review SubTheme"] = labels
        output["Review SubTheme-issue-new"] = suggested
        output["Review SubTheme-issue-zero"] = zero_picks
        cleanMemory(subThemePipe)
    return output


if st.button("Process", type="primary"):
    cancel_slot = st.empty()
    if uploaded_file is None:
        st.warning("Please upload a file first", icon="⚠️")
    else:
        try:
            reviews = _read_upload(uploaded_file)
            if "text" not in reviews.columns:
                st.error(
                    "Please Make sure that your data must have a column named text",
                    icon="🚨",
                )
                st.info("Text column must have amazon reviews", icon="ℹ️")
            else:
                # Missing cells become "" so the models always receive strings.
                texts = (
                    reviews["text"].head(_MAX_ROWS).fillna("").astype(str).tolist()
                )
                outputdf = _process_reviews(texts, cancel_slot)
                st.download_button(
                    label="Download output as CSV",
                    data=convert_df(outputdf),
                    file_name=f"{summarizer_option}_{date}_df.csv",
                    mime="text/csv",
                    use_container_width=True,
                )
        except Exception:
            # Catch Exception, not BaseException, so KeyboardInterrupt,
            # SystemExit and Streamlit's script-control exceptions propagate.
            logging.exception("An exception was occurred")
            st.error("Processing failed; see the server log for details.", icon="🚨")