File size: 6,104 Bytes
d5b2eed
 
bce177f
d5b2eed
 
 
 
 
 
bce177f
82b7df7
 
81a18f0
bce177f
 
82b7df7
 
 
 
bce177f
 
 
 
 
 
 
e3032e8
 
d5b2eed
 
 
 
 
 
bce177f
5a5a81e
e91bd7c
d5b2eed
 
 
 
 
 
 
 
e91bd7c
d5b2eed
e3032e8
 
d5b2eed
 
e3032e8
5a5a81e
 
d5b2eed
 
 
5a5a81e
fb34e92
d5b2eed
fb34e92
d5b2eed
 
fb34e92
e91bd7c
29025ba
40bc8d5
 
 
 
5a5a81e
40bc8d5
e91bd7c
 
d5b2eed
 
 
 
 
 
e3032e8
d5b2eed
 
 
 
 
 
bce177f
 
 
 
29025ba
bce177f
5a5a81e
 
 
bce177f
1d91315
d5b2eed
 
bce177f
 
 
 
 
 
d5b2eed
 
e91bd7c
 
bce177f
d5b2eed
a54b97e
bce177f
963adc8
 
 
 
40bc8d5
 
 
 
 
 
 
 
963adc8
40bc8d5
 
 
963adc8
 
 
1a555d0
963adc8
5a5a81e
 
1d91315
40bc8d5
bce177f
 
e91bd7c
d5b2eed
bce177f
29025ba
1d91315
029862d
d5b2eed
 
bce177f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.
import os
import random
from urllib.parse import parse_qs

import gradio as gr
import requests
from transformers import pipeline
from huggingface_hub import Repository
from dotenv import load_dotenv
from pathlib import Path
import json

# Settings for storing the mturk HITs in a Hugging Face dataset.
# A local ".env" file, when present, is loaded before the environment is read.
if Path(".env").is_file():
    load_dotenv(".env")

# The collected examples are appended to data/data.jsonl.
DATA_FILENAME = "data.jsonl"
DATA_FILE = os.path.join("data", DATA_FILENAME)

DATASET_REPO_URL = os.environ.get("DATASET_REPO_URL")
HF_TOKEN = os.environ.get("HF_TOKEN")

# Repository handle for the "data" directory, which also holds DATA_FILE.
repo = Repository(
    clone_from=DATASET_REPO_URL,
    local_dir="data",
    use_auth_token=HF_TOKEN,
)

# Now let's run the app!
# Sentiment classifier used as the model in the loop. No model name is
# passed, so the pipeline's default sentiment-analysis checkpoint is used.
pipe = pipeline("sentiment-analysis")

# Container for the whole UI; components and handlers are declared inside
# the `with demo:` context below.
demo = gr.Blocks()

with demo:
    total_cnt = 2 # How many examples per HIT
    # Hidden textbox: on each example submit the click handler's JS puts
    # window.location.search into this slot so that _predict can read the
    # assignmentId query parameter.
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    # We keep track of state as a JSON
    #   "assignmentId": mturk assignment id ("" until one is seen in the URL)
    #   "cnt":          number of examples submitted so far
    #   "data":         list of per-example records appended by _predict
    state_dict = {"assignmentId": "", "cnt": 0, "data": []}
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    # Progress line; _predict returns the refreshed text for it on every submit.
    state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state, dummy):
        """Classify one submitted example and update the HIT state.

        Args:
            txt: Statement entered by the user.
            tgt: Target label picked in the radio input ("Positive" or
                "Negative").
            state: HIT state dict with the keys "assignmentId", "cnt" and
                "data". It is updated in place and also returned.
            dummy: The page's ``window.location.search`` string (for example
                "?assignmentId=..."), forwarded by the click handler's JS.

        Returns:
            A tuple in the order of the click handler's outputs: label
            confidences, result markdown, updated state, visibility update
            for the per-example submit column, visibility update for the
            final submit column, progress markdown, and ``dummy`` unchanged.
        """
        pred = pipe(txt)[0]
        predicted = pred["label"].lower()
        other_label = "negative" if predicted == "positive" else "positive"
        pred_confidences = {predicted: pred["score"], other_label: 1 - pred["score"]}

        # The radio choices are title-cased, so compare labels in that form.
        pred["label"] = pred["label"].title()
        ret = f"Target: **{tgt}**. Model prediction: **{pred['label']}**\n\n"
        # NOTE(review): if no target is selected, tgt is presumably None and
        # every prediction then counts as "fooled" - confirm whether such a
        # submission should be rejected instead.
        fooled = pred["label"] != tgt
        if fooled:
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        state["data"].append(
            {
                "cnt": state["cnt"],
                "text": txt,
                "target": tgt,
                "model_pred": pred["label"],
                "fooled": fooled,
            }
        )
        state["cnt"] += 1

        # ">=" rather than "==" so that an overshoot (e.g. a duplicate click)
        # can never leave the HIT without a final submit button.
        done = state["cnt"] >= total_cnt
        toggle_final_submit = gr.update(visible=done)
        toggle_example_submit = gr.update(visible=not done)

        # The state dict has no "fooled" key, so the tally is derived from
        # the per-example records instead of being read from the state.
        num_fooled = sum(1 for datum in state["data"] if datum["fooled"])
        new_state_md = f"State: {state['cnt']}/{total_cnt} ({num_fooled} fooled)"

        # Drop the leading "?" of the query string; tolerate a missing value.
        query = parse_qs((dummy or "").lstrip("?"))
        if "assignmentId" in query:
            # It seems that someone is using this app on mturk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked. We can do this here in _predict. We need to save the
            # assignmentId so that the turker can get credit for their HIT.
            state["assignmentId"] = query["assignmentId"][0]

        return pred_confidences, ret, state, toggle_example_submit, toggle_final_submit, new_state_md, dummy

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    # The two target labels are shuffled once, when this script builds the
    # UI, so their display order is randomized at that point.
    labels = ["Positive", "Negative"]
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    # Outputs filled by _predict: label confidences and the result message.
    label_output = gr.Label()
    text_output = gr.Markdown()
    # Two button containers whose visibility _predict toggles: the
    # per-example "Submit" is shown first, and "Submit HIT" starts hidden
    # until the required number of examples has been entered.
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Store the HIT data into a Hugging Face dataset.
    # The HIT is also stored and logged on mturk when post_hit_js is run below.
    # This _store_in_huggingface_dataset function just demonstrates how easy it is
    # to automatically create a Hugging Face dataset from mturk.
    def _store_in_huggingface_dataset(state):
        """Append the HIT's examples to the JSONL data file and push the repo.

        Every record in ``state["data"]`` becomes one JSON line with the
        HIT's "assignmentId" merged in (record keys win on a clash).

        Args:
            state: HIT state dict with the keys "assignmentId" and "data".

        Returns:
            ``state``, unchanged.
        """
        with open(DATA_FILE, "a") as jsonlfile:
            lines = []
            for datum in state["data"]:
                tagged = dict({"assignmentId": state["assignmentId"]}, **datum)
                lines.append(json.dumps(tagged))
            # One record per line, with a trailing newline after the batch.
            payload = "\n".join(lines) + "\n"
            jsonlfile.write(payload)
        # Publish the appended file only after it has been closed.
        repo.push_to_hub()
        return state

    # Button event handlers
    # JS passed as `_js` to the "Submit" click handler: it forwards the first
    # three inputs unchanged and replaces the fourth (the hidden `dummy`
    # textbox value) with window.location.search, which is how the page's
    # query string reaches _predict.
    get_window_location_search_js = """
        function(text_input, label_input, state, dummy) {
            return [text_input, label_input, state, window.location.search];
        }
        """

    # The outputs list must stay in the same order as _predict's return tuple.
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state, dummy],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display, dummy],
        _js=get_window_location_search_js,
    )

    # JS passed as `_js` to the "Submit HIT" click handler. With a non-empty
    # assignmentId it builds a hidden form holding one field per state key
    # and posts it to mturk's externalSubmit URL; otherwise it reloads the
    # page.
    # NOTE(review): state["data"] is an array of objects and is assigned
    # directly to an input's value, which presumably coerces it to a string
    # without JSON encoding - confirm the posted form carries usable data.
    # NOTE(review): the form action points at the "workersandbox" host;
    # presumably this must change for live (non-sandbox) HITs - confirm.
    # NOTE(review): the else branch returns no value for the `state` input;
    # verify how the Python handler behaves in that case.
    post_hit_js = """
        function(state) {
            if (state["assignmentId"] !== ""){
                // If there is an assignmentId, then the submitter is on mturk
                // and we need to submit their HIT.
                const form = document.createElement('form');
                form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
                form.method = 'post';
                for (const key in state) {
                    const hiddenField = document.createElement('input');
                    hiddenField.type = 'hidden';
                    hiddenField.name = key;
                    hiddenField.value = state[key];
                    form.appendChild(hiddenField);
                };
                document.body.appendChild(form);
                form.submit();
                return [state];
            } else {
                // If there is no assignmentId, then we assume that the submitter is
                // on huggingface.co and we can't submit a HIT to mturk. But
                // _store_in_huggingface_dataset will still store their example in
                // our dataset without an assignmentId. The following line here
                // loads the app again so the user can enter in another "fake" HIT.
                window.location.href = window.location.href;
            }
        }
        """

    # Final submit: store the collected examples in the Hugging Face dataset,
    # with post_hit_js attached as the handler's `_js`.
    submit_hit_button.click(
        _store_in_huggingface_dataset,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )

# Start serving the app.
demo.launch()