Spaces:
Runtime error
Runtime error
File size: 6,104 Bytes
d5b2eed bce177f d5b2eed bce177f 82b7df7 81a18f0 bce177f 82b7df7 bce177f e3032e8 d5b2eed bce177f 5a5a81e e91bd7c d5b2eed e91bd7c d5b2eed e3032e8 d5b2eed e3032e8 5a5a81e d5b2eed 5a5a81e fb34e92 d5b2eed fb34e92 d5b2eed fb34e92 e91bd7c 29025ba 40bc8d5 5a5a81e 40bc8d5 e91bd7c d5b2eed e3032e8 d5b2eed bce177f 29025ba bce177f 5a5a81e bce177f 1d91315 d5b2eed bce177f d5b2eed e91bd7c bce177f d5b2eed a54b97e bce177f 963adc8 40bc8d5 963adc8 40bc8d5 963adc8 1a555d0 963adc8 5a5a81e 1d91315 40bc8d5 bce177f e91bd7c d5b2eed bce177f 29025ba 1d91315 029862d d5b2eed bce177f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 |
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.
import os
import random
from urllib.parse import parse_qs
import gradio as gr
import requests
from transformers import pipeline
from huggingface_hub import Repository
from dotenv import load_dotenv
from pathlib import Path
import json
# Settings used to store the mturk HITs in a Hugging Face dataset.
# If a local ".env" file exists it is loaded first, so the environment
# variables read below can come from it.
if Path(".env").is_file():
    load_dotenv(".env")

DATASET_REPO_URL = os.environ.get("DATASET_REPO_URL")
HF_TOKEN = os.environ.get("HF_TOKEN")
DATA_FILENAME = "data.jsonl"
DATA_FILE = os.path.join("data", DATA_FILENAME)

# Local working copy of the dataset repo; DATA_FILE lives inside its
# "data" directory.
repo = Repository(
    local_dir="data",
    clone_from=DATASET_REPO_URL,
    use_auth_token=HF_TOKEN,
)

# Now let's run the app!
pipe = pipeline("sentiment-analysis")
demo = gr.Blocks()

with demo:
    total_cnt = 2  # How many examples per HIT
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId

    # We keep track of state as a JSON
    state_dict = {"assignmentId": "", "cnt": 0, "data": []}
    state = gr.JSON(state_dict, visible=False)

    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state, dummy):
        """Score one submitted example and advance the HIT state.

        Args:
            txt: The statement entered by the user.
            tgt: The label selected in the "Target (correct) label" radio.
            state: The state dict with "assignmentId", "cnt" and "data" keys.
            dummy: The page's ``window.location.search`` string, supplied by
                the click handler's JS; parsed for an "assignmentId" value.

        Returns:
            A 7-tuple in the order of the click handler's ``outputs``: the
            label confidences, the result markdown, the updated state, the
            visibility updates for the example-submit and final-submit
            columns, the progress markdown, and ``dummy`` unchanged.
        """
        pred = pipe(txt)[0]
        pred_label = pred["label"].lower()
        other_label = "negative" if pred_label == "positive" else "positive"
        pred_confidences = {pred_label: pred["score"], other_label: 1 - pred["score"]}

        # Title-case the label so it is comparable with the radio choices
        # ("Positive" / "Negative").
        pred["label"] = pred["label"].title()
        ret = f"Target: **{tgt}**. Model prediction: **{pred['label']}**\n\n"
        fooled = pred["label"] != tgt
        if fooled:
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"

        state["data"].append(
            {
                "cnt": state["cnt"],
                "text": txt,
                "target": tgt,
                "model_pred": pred["label"],
                "fooled": fooled,
            }
        )
        state["cnt"] += 1

        # Once the HIT is full, swap the per-example button for the final one.
        done = state["cnt"] == total_cnt
        toggle_final_submit = gr.update(visible=done)
        toggle_example_submit = gr.update(visible=not done)

        # The state dict has no "fooled" key (only "assignmentId", "cnt" and
        # "data"), so the tally is derived from the per-example records.
        fooled_cnt = sum(1 for datum in state["data"] if datum["fooled"])
        new_state_md = f"State: {state['cnt']}/{total_cnt} ({fooled_cnt} fooled)"

        # dummy[1:] drops the first character of the search string (the "?"
        # when a query is present); an empty string parses to an empty dict.
        query = parse_qs(dummy[1:])
        if "assignmentId" in query:
            # It seems that someone is using this app on mturk. We need to
            # store the assignmentId in the state before submit_hit_button
            # is clicked. We can do this here in _predict. We need to save the
            # assignmentId so that the turker can get credit for their HIT.
            state["assignmentId"] = query["assignmentId"][0]

        return pred_confidences, ret, state, toggle_example_submit, toggle_final_submit, new_state_md, dummy

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()

    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Store the HIT data into a Hugging Face dataset.
    # The HIT is also stored and logged on mturk when post_hit_js is run below.
    # This _store_in_huggingface_dataset function just demonstrates how easy it is
    # to automatically create a Hugging Face dataset from mturk.
    def _store_in_huggingface_dataset(state):
        """Append the HIT's examples to DATA_FILE and push the repo to the Hub.

        Each record in ``state["data"]`` is written as one JSON line with the
        state's "assignmentId" added to it.

        Args:
            state: The state dict with "assignmentId" and "data" keys.

        Returns:
            ``state`` unchanged, for the click handler's ``outputs``.
        """
        json_data_with_assignment_id = [
            json.dumps({"assignmentId": state["assignmentId"], **datum})
            for datum in state["data"]
        ]
        with open(DATA_FILE, "a") as jsonlfile:
            # One record per line; nothing is written when there is no data.
            jsonlfile.writelines(f"{line}\n" for line in json_data_with_assignment_id)
        # Push only after the file has been closed, so the appended lines are
        # on disk.
        repo.push_to_hub()
        return state

    # Button event handlers
    # The JS runs in the browser before _predict and replaces the fourth
    # input (dummy) with the page's query string.
    get_window_location_search_js = """
function(text_input, label_input, state, dummy) {
return [text_input, label_input, state, window.location.search];
}
"""

    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state, dummy],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display, dummy],
        _js=get_window_location_search_js,
    )

    post_hit_js = """
function(state) {
if (state["assignmentId"] !== ""){
// If there is an assignmentId, then the submitter is on mturk
// and we need to submit their HIT.
const form = document.createElement('form');
form.action = 'https://workersandbox.mturk.com/mturk/externalSubmit';
form.method = 'post';
for (const key in state) {
const hiddenField = document.createElement('input');
hiddenField.type = 'hidden';
hiddenField.name = key;
// Non-string values (e.g. the "data" array of objects) are serialized as
// JSON; assigning them directly would coerce them to "[object Object]".
hiddenField.value = typeof state[key] === 'string' ? state[key] : JSON.stringify(state[key]);
form.appendChild(hiddenField);
};
document.body.appendChild(form);
form.submit();
return [state];
} else {
// If there is no assignmentId, then we assume that the submitter is
// on huggingface.co and we can't submit a HIT to mturk. But
// _store_in_huggingface_dataset will still store their example in
// our dataset without an assignmentId. The following line here
// loads the app again so the user can enter in another "fake" HIT.
window.location.href = window.location.href;
}
}
"""

    submit_hit_button.click(
        _store_in_huggingface_dataset,
        inputs=[state],
        outputs=[state],
        _js=post_hit_js,
    )
# Start serving the Blocks app.
demo.launch()