Spaces:
Runtime error
Runtime error
File size: 3,447 Bytes
d5b2eed e3032e8 d5b2eed 80bfae7 d5b2eed e3032e8 d5b2eed e3032e8 d5b2eed fb34e92 d5b2eed fb34e92 d5b2eed fb34e92 d5b2eed fb34e92 e3032e8 d5b2eed e3032e8 d5b2eed a868cbd fb34e92 7ed08ed fb34e92 d5b2eed e3032e8 d5b2eed a54b97e d5b2eed a54b97e e22cf04 d5b2eed 73f2dfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# Basic example for doing model-in-the-loop dynamic adversarial data collection
# using Gradio Blocks.
import random
from urllib.parse import parse_qs
import gradio as gr
import requests
from transformers import pipeline
# Sentiment model used to judge worker examples. It is pinned explicitly instead of
# relying on the task's implicit default, which may change between transformers
# releases. _predict assumes a two-label POSITIVE/NEGATIVE output.
pipe = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")
demo = gr.Blocks()
with demo:
    total_cnt = 2  # How many examples per HIT
    dummy = gr.Textbox(visible=False)  # dummy for passing assignmentId
    # We keep track of state as a Variable.
    # NOTE(review): gr.Variable is the pre-gr.State API; confirm against the
    # pinned Gradio version before upgrading.
    state_dict = {"assignmentId": "", "cnt": 0, "fooled": 0, "data": [], "metadata": {}}
    state = gr.Variable(state_dict)
    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")
    state_display = gr.Markdown(f"State: 0/{total_cnt} (0 fooled)")

    def _progress(state):
        """Return (example-submit update, HIT-submit update, progress markdown) for ``state``."""
        done = state["cnt"] >= total_cnt
        return (
            gr.update(visible=not done),
            gr.update(visible=done),
            f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)",
        )

    # Generate model prediction
    # Default model: distilbert-base-uncased-finetuned-sst-2-english
    def _predict(txt, tgt, state):
        """Score one worker example against the model and update session state.

        Args:
            txt: Statement typed by the worker.
            tgt: Target label picked in the radio ("Positive"/"Negative"), or
                None when nothing has been selected.
            state: Per-session state dict shaped like ``state_dict``.

        Returns:
            A 6-tuple matching the ``outputs`` of ``submit_ex_button.click``:
            label confidences, result markdown, updated state, visibility
            updates for the example / HIT submit columns, and progress text.
        """
        # Reject incomplete submissions without touching the counters. With
        # tgt=None the comparison below would always count as "fooled".
        if not tgt or not txt or not txt.strip():
            msg = "Please enter a statement and select a target label before submitting."
            return (gr.update(), msg, state, *_progress(state))
        # Guard against extra submissions (e.g. a queued double-click) after
        # the HIT quota has been reached.
        if state["cnt"] >= total_cnt:
            msg = "All examples for this HIT are done; please submit the HIT."
            return (gr.update(), msg, state, *_progress(state))

        pred = pipe(txt)[0]
        # The model reports only its top label; derive the complementary score.
        other_label = "negative" if pred["label"].lower() == "positive" else "positive"
        pred_confidences = {pred["label"].lower(): pred["score"], other_label: 1 - pred["score"]}
        # Title-case so the label compares equal to the radio choices.
        pred["label"] = pred["label"].title()
        ret = f"Target: **{tgt}**. Model prediction: **{pred['label']}**\n\n"
        if pred["label"] != tgt:
            state["fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"
        # NOTE(review): only the feedback markdown is recorded; the worker's
        # statement (txt) itself is not stored -- confirm this is intended.
        state["data"].append(ret)
        state["cnt"] += 1
        return (pred_confidences, ret, state, *_progress(state))

    # Input fields
    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    random.shuffle(labels)  # randomize radio order (once per app launch)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Submit state to MTurk backend for ExternalQuestion
    # Update the URL below to switch from Sandbox to real data collection
    def _submit(state, dummy):
        """Post the collected session state to the MTurk externalSubmit endpoint.

        Args:
            state: Per-session state dict; ``assignmentId`` is filled in here.
            dummy: The page's ``window.location.search`` (e.g. "?assignmentId=..."),
                injected by the ``_js`` hook on ``submit_hit_button``.

        Returns:
            (response markdown, updated state, dummy) matching the click outputs.
        """
        query = parse_qs((dummy or "").lstrip("?"))
        assignment_id = query.get("assignmentId", [""])[0]
        # MTurk previews a HIT with assignmentId=ASSIGNMENT_ID_NOT_AVAILABLE;
        # such a page has not been accepted and must not submit.
        if not assignment_id or assignment_id == "ASSIGNMENT_ID_NOT_AVAILABLE":
            return "No assignment ID provided, unable to submit", state, dummy
        state["assignmentId"] = assignment_id
        url = "https://workersandbox.mturk.com/mturk/externalSubmit"
        try:
            # Bound the wait so a stalled endpoint cannot hang the handler.
            x = requests.post(url, data=state, timeout=30)
        except requests.RequestException as err:
            return f"Submission failed: {err}", state, dummy
        return str(x) + " With assignmentId " + state["assignmentId"] + "\n" + x.text, state, dummy

    # Button event handlers
    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display],
    )
    response_output = gr.Markdown()
    submit_hit_button.click(
        _submit,
        inputs=[state, dummy],
        outputs=[response_output, state, dummy],
        # Replace the hidden textbox value with the page's query string so the
        # server can read the MTurk assignmentId.
        _js="function(state, dummy) { return [state, window.location.search]; }",
    )
# Start the Gradio server for the Blocks app defined above.
demo.launch()