import os
import zipfile

import gradio as gr
import nltk
import pandas as pd
import requests

from pyabsa import TADCheckpointManager
from textattack.attack_recipes import (
    BAEGarg2019,
    PWWSRen2019,
    TextFoolerJin2019,
    PSOZang2020,
    IGAWang2019,
    GeneticAlgorithmAlzantot2018,
    DeepWordBugGao2018,
    CLARE2020,
)
from textattack.attack_results import SuccessfulAttackResult
from utils import SentAttacker, get_agnews_example, get_sst2_example, get_amazon_example, get_imdb_example, diff_texts
# from utils import get_yahoo_example

# Filled by init(): maps "tad-<dataset><attacker>" to a SentAttacker.
sent_attackers = dict()
# Filled by init(): maps "tad-<dataset>" to a TAD text classifier.
tad_classifiers = dict()

# Lower-case attacker name (as used in the keys above) -> TextAttack recipe class.
attack_recipes = dict(
    bae=BAEGarg2019,
    pwws=PWWSRen2019,
    textfooler=TextFoolerJin2019,
    pso=PSOZang2020,
    iga=IGAWang2019,
    ga=GeneticAlgorithmAlzantot2018,
    deepwordbug=DeepWordBugGao2018,
    clare=CLARE2020,
)


def init():
    """Prepare NLTK data, checkpoints, TAD classifiers and attackers.

    Populates the module-level ``tad_classifiers`` (keyed ``"tad-<dataset>"``)
    and ``sent_attackers`` (keyed ``"tad-<dataset><attacker>"``). Dataset names
    are lower-case so the keys match the lookups in
    ``generate_adversarial_example``, which lower-cases the UI selection
    (previously the "MR" entries were stored as ``"tad-MR..."`` and could
    never be found).
    """
    nltk.download("omw-1.4")

    # Skip extraction when TAD-SST2 already exists (presumably left by an
    # earlier extraction of checkpoints.zip). The context manager closes the
    # archive handle, which was previously leaked.
    if not os.path.exists("TAD-SST2"):
        with zipfile.ZipFile("checkpoints.zip", "r") as archive:
            archive.extractall(os.getcwd())

    for attacker in ["pwws", "bae", "textfooler", "deepwordbug"]:
        for dataset in ["agnews10k", "sst2", "mr", "imdb"]:
            classifier_key = "tad-{}".format(dataset)
            if classifier_key not in tad_classifiers:
                # Checkpoint names are the upper-cased key, e.g. "TAD-MR".
                tad_classifiers[
                    classifier_key
                ] = TADCheckpointManager.get_tad_text_classifier(
                    classifier_key.upper()
                )

            sent_attackers[classifier_key + attacker] = SentAttacker(
                tad_classifiers[classifier_key], attack_recipes[attacker]
            )
            # The classifier's own sent_attacker is always the PWWS one; it
            # exists here because "pwws" is the first attacker processed.
            tad_classifiers[classifier_key].sent_attacker = sent_attackers[
                classifier_key + "pwws"
            ]


# Texts already attacked in this process. generate_adversarial_example draws a
# fresh random example instead of re-attacking any text found here.
cache = set()


def generate_adversarial_example(dataset, attacker, text=None, label=None, max_attempts=100):
    """Attack an example with the chosen recipe, then repair it with the TAD classifier.

    If ``text`` is empty or was already attacked in this process (see
    ``cache``), a random example is drawn from ``dataset`` instead. When an
    attempt does not yield a usable adversarial example (the attack did not
    succeed, the clean prediction was already wrong, or the defended inference
    returned nothing), a new random example is drawn and the process repeats,
    up to ``max_attempts`` times. This replaces the previous unbounded
    self-recursion, which could end in RecursionError.

    Args:
        dataset: Dataset name as shown in the UI (e.g. "SST2", "MR"); matched
            case-insensitively.
        attacker: Attacker name as shown in the UI (e.g. "TextFooler"); the
            attacker lookup is case-insensitive, while the original value is
            forwarded as ``defense`` to the classifier.
        text: Optional natural example supplied by the user.
        label: Ground-truth label of ``text``; must be convertible with ``int``.
        max_attempts: Maximum number of examples to try before giving up.

    Returns:
        An 11-tuple: (original text, original label, restored text, restored
        label, adversarial text, diff original/original, diff
        original/adversarial, diff original/restored, adversarial prediction,
        classification DataFrame, adversarial-detection DataFrame).

    Raises:
        RuntimeError: if no usable example was produced within ``max_attempts``.
    """
    dataset_key = dataset.lower()
    attacker_key = attacker.lower()

    for _ in range(max_attempts):
        # Draw a random example when none was given or this text was already used.
        if not text or text in cache:
            if "agnews" in dataset_key:
                text, label = get_agnews_example()
            elif "sst2" in dataset_key:
                text, label = get_sst2_example()
            elif "mr" in dataset_key:
                # Previously compared against "MR" on a lower-cased string, so
                # this branch never ran and MR requests had no example.
                # NOTE(review): MR samples come from get_amazon_example — confirm intended.
                text, label = get_amazon_example()
            # elif "yahoo" in dataset_key:
            #     text, label = get_yahoo_example()
            elif "imdb" in dataset_key:
                text, label = get_imdb_example()

        cache.add(text)

        result = None
        attack_result = sent_attackers[
            "tad-{}{}".format(dataset_key, attacker_key)
        ].attacker.simple_attack(text, int(label))
        # Only defend examples whose clean prediction was correct and whose
        # perturbed prediction is wrong, i.e. genuine adversarial examples.
        if (
            isinstance(attack_result, SuccessfulAttackResult)
            and attack_result.perturbed_result.output
            != attack_result.original_result.ground_truth_output
            and attack_result.original_result.output
            == attack_result.original_result.ground_truth_output
        ):
            # with defense
            result = tad_classifiers["tad-{}".format(dataset_key)].infer(
                attack_result.perturbed_result.attacked_text.text
                + "$LABEL${},{},{}".format(
                    attack_result.original_result.ground_truth_output,
                    1,
                    attack_result.perturbed_result.output,
                ),
                print_result=True,
                defense=attacker,
            )

        if result:
            break
        # Retry with a fresh random example, as the recursive version did.
        text, label = None, None
    else:
        raise RuntimeError(
            "No usable adversarial example was generated after {} attempts.".format(
                max_attempts
            )
        )

    classification_df = {
        "is_repaired": result["is_fixed"],
        "pred_label": result["label"],
        "confidence": round(result["confidence"], 3),
        # NOTE(review): this compares result["pred_label"] while the pred_label
        # column uses result["label"] — confirm both keys exist and agree.
        "is_correct": str(result["pred_label"]) == str(label),
    }

    advdetection_df = {}
    if result["is_adv_label"] != "0":
        advdetection_df["is_adversarial"] = {
            "0": False,
            "1": True,
            0: False,
            1: True,
        }[result["is_adv_label"]]
        advdetection_df["perturbed_label"] = result["perturbed_label"]
        advdetection_df["confidence"] = round(result["is_adv_confidence"], 3)
        advdetection_df["ref_is_attack"] = result["ref_is_adv_label"]
        advdetection_df["is_correct"] = result["ref_is_adv_check"]

    return (
        text,
        label,
        result["restored_text"],
        result["label"],
        attack_result.perturbed_result.attacked_text.text,
        diff_texts(text, text),
        diff_texts(text, attack_result.perturbed_result.attacked_text.text),
        diff_texts(text, result["restored_text"]),
        attack_result.perturbed_result.output,
        pd.DataFrame(classification_df, index=[0]),
        pd.DataFrame(advdetection_df, index=[0]),
    )


def run_demo(dataset, attacker, text=None, label=None):
    """Produce the 12 Gradio outputs, preferring the remote GPU service.

    The request is posted to the remote endpoint first. If that fails for any
    reason (unreachable host, non-2xx status, malformed or incomplete JSON),
    the example is generated locally with ``generate_adversarial_example``.

    Returns:
        A 12-tuple matching the ``outputs`` bound to the generate button; the
        last element is the text shown in the message box. The local fallback
        previously returned only 11 values, which did not match the 12 bound
        outputs.
    """
    data = {
        "dataset": dataset,
        "attacker": attacker,
        "text": text,
        "label": label,
    }
    try:
        # Bound only the connect phase: remote generation itself can take up
        # to about a minute, but an unreachable host should fail fast.
        response = requests.post(
            'https://rpddemo.pagekite.me/api/generate_adversarial_example',
            json=data,
            timeout=(5, None),
        )
        response.raise_for_status()
        result = response.json()
        print(result)
        return (
            result["text"],
            result["label"],
            result["restored_text"],
            result["result_label"],
            result["perturbed_text"],
            result["text_diff"],
            result["perturbed_diff"],
            result["restored_diff"],
            result["output"],
            pd.DataFrame(result["classification_df"]),
            pd.DataFrame(result["advdetection_df"]),
            result["message"]
        )
    except Exception as e:
        # Top-level UI boundary: any remote failure falls back to local generation.
        print(e)
        local_outputs = generate_adversarial_example(dataset, attacker, text, label)
        message = "Remote GPU service unavailable ({}); the example was generated locally.".format(e)
        return (*local_outputs, message)


def check_gpu():
    """Report whether the remote GPU endpoint is reachable.

    The endpoint counts as available when it answers within 3 seconds with a
    status code below 500. Any error while probing counts as unavailable.
    """
    reachable = False
    try:
        probe = requests.post('https://rpddemo.pagekite.me/api/generate_adversarial_example', timeout=3)
        reachable = probe.status_code < 500
    except Exception:
        reachable = False
    return 'GPU available' if reachable else 'GPU not available'


if __name__ == "__main__":
    # Local initialization failure is not fatal: run_demo can still use the
    # remote service, so the UI is launched regardless.
    try:
        init()
    except Exception as e:
        print(e)
        print("Failed to initialize the demo. Please try again later.")

    demo = gr.Blocks()

    with demo:
        gr.Markdown("<h1 align='center'>Detection and Correction based on Word Importance Ranking (DCWIR) </h1>")
        gr.Markdown("<h2 align='center'>Clarifications</h2>")
        gr.Markdown("""
    - This demo has no mechanism to ensure that the adversarial example will be correctly repaired by DCWIR.
    - The adversarial example and the corrected adversarial example may read unnaturally; this is because attackers usually generate unnatural perturbations.
    - All the provided attacks are black-box attacks, in which the attacker has no access to the model parameters.
    """)
        gr.Markdown("<h2 align='center'>Natural Example Input</h2>")
        with gr.Group():
            with gr.Row():
                input_dataset = gr.Radio(
                    choices=["SST2", "IMDB", "MR", "AGNews10K"],
                    value="SST2",
                    label="Select a testing dataset and an adversarial attacker to generate an adversarial example.",
                )
                input_attacker = gr.Radio(
                    choices=["BAE", "PWWS", "TextFooler", "DeepWordBug"],
                    value="TextFooler",
                    label="Choose an Adversarial Attacker for generating an adversarial example to attack the model.",
                )
            with gr.Group(visible=True):

                with gr.Row():
                    input_sentence = gr.Textbox(
                        placeholder="Input a natural example...",
                        label="Alternatively, input a natural example and its original label (from above datasets) to generate an adversarial example.",

                    )
                    input_label = gr.Textbox(
                        placeholder="Original label (must be an integer, because we use digits to represent labels in training)",
                        label="Original Label",
                    )
                gr.Markdown(
                    "<h3 align='center'>Default parameters are set according to the main experiment setup in the report.</h3>",
                )
        # NOTE(review): these three textboxes are not passed to any event
        # handler below, so their values currently have no effect.
        with gr.Row():
            wir_percentage = gr.Textbox(
                placeholder="Enter percentage from WIR...",
                label="Percentage from WIR",
            )
            frequency_threshold = gr.Textbox(
                placeholder="Enter frequency threshold...",
                label="Frequency Threshold",
            )
            max_candidates = gr.Textbox(
                placeholder="Enter maximum number of candidates...",
                label="Maximum Number of Candidates",
            )
        msg_text = gr.Textbox(
            label="Message",
            placeholder="This is a message box to show any error messages.",
        )
        button_gen = gr.Button(
            "Generate an adversarial example to repair using DCWIR (GPU: < 1 minute, CPU: 1-10 minutes)",
            variant="primary",
        )
        gpu_status_text = gr.Textbox(
            label='GPU status',
            placeholder="Please click to check",
        )
        button_check = gr.Button(
            "Check if GPU available",
            variant="primary"
        )

        button_check.click(
            fn=check_gpu,
            inputs=[],
            outputs=[
                gpu_status_text
            ]
        )

        gr.Markdown("<h2 align='center'>Generated Adversarial Example and Repaired Adversarial Example</h2>")

        with gr.Column():
            with gr.Group():
                with gr.Row():
                    output_original_example = gr.Textbox(label="Original Example")
                    output_original_label = gr.Textbox(label="Original Label")
                with gr.Row():
                    output_adv_example = gr.Textbox(label="Adversarial Example")
                    output_adv_label = gr.Textbox(label="Predicted Label of the Adversarial Example")
                with gr.Row():
                    output_repaired_example = gr.Textbox(
                        label="Repaired Adversarial Example by Rapid"
                    )
                    output_repaired_label = gr.Textbox(label="Predicted Label of the Repaired Adversarial Example")

        gr.Markdown("<h2 align='center'>Example Difference (Comparisons)</h2>")
        gr.Markdown("""
        <p align='center'>The (+) and (-) in the boxes indicate the added and deleted characters in the adversarial example compared to the original input natural example.</p>
            """)
        ori_text_diff = gr.HighlightedText(
            label="The Original Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )
        adv_text_diff = gr.HighlightedText(
            label="Character Editions of Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )

        restored_text_diff = gr.HighlightedText(
            label="Character Editions of Repaired Adversarial Example Compared to the Natural Example",
            combine_adjacent=True,
            show_legend=True,
        )

        gr.Markdown(
            "<h2 align='center'>The Output of Reactive Perturbation Defocusing</h2>"
        )
        with gr.Row():
            with gr.Column():
                with gr.Group():
                    output_is_adv_df = gr.DataFrame(
                        label="Adversarial Example Detection Result"
                    )
                    gr.Markdown(
                        """
                         - The is_adversarial field indicates if an adversarial example is detected.
                         - The perturbed_label is the predicted label of the adversarial example.
                         - The confidence field represents the ratio of inverted samples among the total number of generated candidates.
                         """
                    )
            with gr.Column():
                with gr.Group():
                    output_df = gr.DataFrame(
                        label="Correction Classification Result"
                    )
                    gr.Markdown(
                      """
                        - If is_repaired is True, the example has been corrected by DCWIR.
                        - The pred_label field indicates the standard classification result.
                        - The confidence field represents the ratio of the dominant class among all inverted candidates.
                        - The is_correct field indicates whether the predicted label is correct.

                        """
                    )

        # Bind functions to buttons. The order of `outputs` must match the
        # 12-tuple returned by run_demo.
        button_gen.click(
            fn=run_demo,
            inputs=[input_dataset, input_attacker, input_sentence, input_label],
            outputs=[
                output_original_example,
                output_original_label,
                output_repaired_example,
                output_repaired_label,
                output_adv_example,
                ori_text_diff,
                adv_text_diff,
                restored_text_diff,
                output_adv_label,
                output_df,
                output_is_adv_df,
                msg_text
            ],
        )

    demo.queue(2).launch()