# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.
import os
import random
from urllib.parse import parse_qs

import gradio as gr
import requests
from transformers import pipeline
from huggingface_hub import Repository
from dotenv import load_dotenv
from pathlib import Path
import json
from filelock import FileLock

# These variables are for storing the mturk HITs in a Hugging Face dataset.
# Settings are read from the environment; when a local ".env" file exists it
# is loaded first via python-dotenv so os.getenv below can see its values.
if Path(".env").is_file():
    load_dotenv(".env")
DATASET_REPO_URL = os.getenv("DATASET_REPO_URL")  # None when the variable is unset
HF_TOKEN = os.getenv("HF_TOKEN")  # None when the variable is unset
DATA_FILENAME = "data.jsonl"
# Path of the JSONL file that collected examples are appended to; it lives
# inside the "data" directory used as the repository's local_dir below.
DATA_FILE = os.path.join("data", DATA_FILENAME)
# Created at import time with local_dir="data" and clone_from set to the
# dataset URL. NOTE(review): presumably this clones the dataset repo into
# ./data on startup — confirm against the huggingface_hub Repository docs.
repo = Repository(
    local_dir="data", clone_from=DATASET_REPO_URL, use_auth_token=HF_TOKEN
)

# Now let's run the app!
# Sentiment model that users try to fool; no model name is given, so the
# transformers default for the "sentiment-analysis" task is used.
pipe = pipeline("sentiment-analysis")

demo = gr.Blocks()

# Everything created inside this context becomes part of the Blocks app.
with demo:
    total_cnt = 2 # How many examples per HIT
    # Hidden textbox used as a slot for window.location.search, which the
    # example-submit handler's JavaScript fills in on every click.
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    # We keep track of state as a JSON
    #   assignmentId: mturk assignment id parsed from the URL in _predict
    #   cnt:          number of examples submitted so far
    #   cnt_fooled:   how many of those fooled the model
    #   data:         one record per submitted example (appended in _predict)
    state_dict = {"assignmentId": "", "cnt": 0, "cnt_fooled": 0, "data": []}
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    # Progress line; _predict returns a fresh string for it after each example.
    state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state, dummy):
        """Score one submitted example and advance the HIT state.

        Args:
            txt: Statement typed by the user; passed to the sentiment pipeline.
            tgt: Target label chosen in the radio ("Positive" or "Negative"),
                or None when nothing has been selected yet.
            state: Dict with "assignmentId", "cnt", "cnt_fooled" and "data".
                It is updated in place and returned.
            dummy: The page's ``window.location.search`` string, supplied by
                the click handler's JavaScript.

        Returns:
            A tuple in the order of the click handler's ``outputs``: label
            confidences, result markdown, state, visibility updates for the
            example-submit / final-submit / preview-submit columns, the
            progress markdown, and ``dummy`` unchanged.
        """
        if tgt is None:
            # No target label was chosen. Continuing would count the example
            # and then fail on ``tgt.lower()``, so ask the user to pick a
            # label and leave every other component and the state untouched.
            # Each output gets its own update object on purpose.
            return (
                gr.update(),
                "Please select a target label before submitting.",
                state,
                gr.update(),
                gr.update(),
                gr.update(),
                gr.update(),
                dummy,
            )

        pred = pipe(txt)[0]
        pred_label = pred["label"].lower()
        other_label = "negative" if pred_label == "positive" else "positive"
        pred_confidences = {pred_label: pred["score"], other_label: 1 - pred["score"]}

        # The radio choices are title-cased, so compare in that form.
        model_label = pred["label"].title()
        ret = f"Target: **{tgt}**. Model prediction: **{model_label}**\n\n"
        fooled = model_label != tgt
        if fooled:
            state["cnt_fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        state["cnt"] += 1

        # Once the quota is reached the per-example button is hidden and one
        # of the two final submit buttons is shown instead.
        done = state["cnt"] == total_cnt
        toggle_example_submit = gr.update(visible=not done)
        new_state_md = f"State: {state['cnt']}/{total_cnt} ({state['cnt_fooled']} fooled)"

        state["data"].append(
            {
                "cnt": state["cnt"],
                "text": txt,
                "target": tgt.lower(),
                "model_pred": model_label.lower(),
                "fooled": fooled,
            }
        )

        # ``dummy`` looks like "?key=value&..."; drop the leading "?".
        query = parse_qs(dummy[1:])
        if "assignmentId" in query and query["assignmentId"][0] != "ASSIGNMENT_ID_NOT_AVAILABLE":
            # It seems that someone is using this app on mturk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked. We can do this here in _predict. We need to save the
            # assignmentId so that the turker can get credit for their HIT.
            state["assignmentId"] = query["assignmentId"][0]
            toggle_final_submit = gr.update(visible=done)
            toggle_final_submit_preview = gr.update(visible=False)
        else:
            toggle_final_submit_preview = gr.update(visible=done)
            toggle_final_submit = gr.update(visible=False)

        return pred_confidences, ret, state, toggle_example_submit, toggle_final_submit, toggle_final_submit_preview, new_state_md, dummy

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    # Shuffled once, when the UI is built (not on every click).
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    # Outputs filled in by _predict: label confidences and a result message.
    label_output = gr.Label()
    text_output = gr.Markdown()
    # Three columns whose visibility _predict toggles: the per-example submit
    # button, the mturk "Submit HIT" button, and the preview-mode button.
    # Only the first is visible initially.
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")
    with gr.Column(visible=False) as final_submit_preview:
        submit_hit_button_preview = gr.Button("Submit Work (preview mode; no mturk HIT credit)")

    # Store the HIT data into a Hugging Face dataset.
    # The HIT is also stored and logged on mturk when post_hit_js is run below.
    # This _store_in_huggingface_dataset function just demonstrates how easy it is
    # to automatically create a Hugging Face dataset from mturk.
    def _store_in_huggingface_dataset(state):
        """Append this HIT's examples to the JSONL data file and push the repo.

        Every record in ``state["data"]`` is merged with the HIT's
        ``assignmentId`` and written as one JSON line. The append and the
        push both happen while holding a file lock placed next to the data
        file. Returns ``state`` unchanged.
        """
        with FileLock(DATA_FILE + ".lock"):
            with open(DATA_FILE, "a") as out:
                rows = []
                for record in state["data"]:
                    tagged = {"assignmentId": state["assignmentId"], **record}
                    rows.append(json.dumps(tagged))
                out.write("\n".join(rows) + "\n")
            # Push only after the file has been closed and flushed.
            repo.push_to_hub()
        return state

    # Button event handlers
    # JavaScript passed as `_js` to the example-submit click below. It takes
    # the same four inputs as _predict and returns them with the fourth
    # (`dummy`) replaced by window.location.search, which is how the URL
    # query string (and therefore the mturk assignmentId) reaches Python.
    # NOTE(review): presumably Gradio runs this in the browser before calling
    # _predict and feeds its return values in as the inputs — confirm against
    # the installed Gradio version.
    get_window_location_search_js = """
        function(text_input, label_input, state, dummy) {
            return [text_input, label_input, state, window.location.search];
        }
        """

    # The outputs list must stay in the same order as _predict's return tuple.
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state, dummy],
        outputs=[label_output, text_output, state, example_submit, final_submit, final_submit_preview, state_display, dummy],
        _js=get_window_location_search_js,
    )

    # JavaScript attached to the "Submit HIT" button. It copies every entry of
    # the state into a hidden form field and POSTs the form to mturk's
    # externalSubmit endpoint so the worker gets credit for the HIT.
    # NOTE(review): the URL below is the mturk *sandbox* endpoint — confirm
    # it is swapped for the production endpoint before running a live HIT.
    post_hit_js = """
        function(state) {
            // If there is an assignmentId, then the submitter is on mturk
            // and has accepted the HIT. So, we need to submit their HIT.
            const form = document.createElement('form');
            form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
            form.method = 'post';
            for (const key in state) {
                const value = state[key];
                const hiddenField = document.createElement('input');
                hiddenField.type = 'hidden';
                hiddenField.name = key;
                // Arrays/objects (e.g. state.data) would otherwise be coerced
                // to "[object Object]" and the collected examples would be
                // lost, so serialize them as JSON. Scalars are left as-is.
                hiddenField.value =
                    (value !== null && typeof value === 'object')
                        ? JSON.stringify(value)
                        : value;
                form.appendChild(hiddenField);
            };
            document.body.appendChild(form);
            form.submit();
            return state;
        }
        """

    # On click, the state is also stored in the Hugging Face dataset.
    submit_hit_button.click(
        _store_in_huggingface_dataset,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )

    # JavaScript attached to the preview-mode submit button: it reloads the
    # page and passes the state through unchanged.
    refresh_app_js = """
        function(state) {
            // The following line here loads the app again so the user can
            // enter in another preview-mode "HIT".
            window.location.href = window.location.href;
            return state;
        }
        """

    # Preview submissions are still stored in the Hugging Face dataset; they
    # are just not posted to mturk.
    submit_hit_button_preview.click(
        _store_in_huggingface_dataset,
        inputs=[state],
        outputs=[state],
        _js=refresh_app_js,
    )

# Start serving the Blocks app; this runs whenever the script is executed.
demo.launch()