|
|
|
|
|
|
|
import random |
|
from urllib.parse import parse_qs |
|
|
|
import gradio as gr |
|
import requests |
|
from transformers import pipeline |
|
|
|
# Sentiment classifier the worker tries to fool. The model is pinned explicitly
# because _predict compares its title-cased labels ("POSITIVE" -> "Positive")
# against the radio choices; an unpinned pipeline silently uses whatever default
# model the installed transformers version ships for this task.
pipe = pipeline("sentiment-analysis", model="distilbert-base-uncased-finetuned-sst-2-english")

# Top-level Blocks app; all UI components are declared inside `with demo:` below.
demo = gr.Blocks()
|
|
|
with demo:
    # Number of examples a worker must submit before "Submit HIT" becomes visible.
    total_cnt = 2

    # Hidden textbox used only as a carrier: the _js hook on the "Submit HIT"
    # button replaces its value with window.location.search (the page query string).
    dummy = gr.Textbox(visible=False)

    # Initial per-session state; "data" collects one record per submitted example.
    state_dict = {"assignmentId": "", "cnt": 0, "fooled": 0, "data": [], "metadata": {}, "response": ""}
    state = gr.Variable(state_dict)

    gr.Markdown("# DADC in Gradio example")
    gr.Markdown("Try to fool the model and find an example where it predicts the wrong label!")

    def _progress_md(state):
        """Return the progress line, e.g. "State: 1/2 (0 fooled)"."""
        return f"State: {state['cnt']}/{total_cnt} ({state['fooled']} fooled)"

    state_display = gr.Markdown(_progress_md(state_dict))

    def _predict(txt, tgt, state):
        """Classify *txt*, compare the prediction with the target label, update progress.

        Args:
            txt: Statement typed by the worker.
            tgt: Target (correct) label picked in the radio ("Positive"/"Negative"),
                or None when nothing was selected.
            state: Per-session state dict (same keys as state_dict); mutated in place.

        Returns:
            Tuple matching the click outputs: (label confidences, result markdown,
            state, example-submit visibility update, final-submit visibility update,
            progress markdown).
        """
        # Reject incomplete input instead of scoring it: with tgt=None the
        # comparison below would count every prediction as "fooled".
        if not (txt and txt.strip()) or tgt not in labels:
            done = state["cnt"] >= total_cnt
            return (
                gr.update(),
                "Please enter a statement and select the target label.",
                state,
                gr.update(visible=not done),
                gr.update(visible=done),
                _progress_md(state),
            )

        pred = pipe(txt)[0]
        # Title-case the model label (e.g. "POSITIVE" -> "Positive") to match the radio choices.
        label = pred["label"].title()
        other_label = "negative" if label.lower() == "positive" else "positive"
        pred_confidences = {label.lower(): pred["score"], other_label: 1 - pred["score"]}

        fooled = label != tgt
        ret = f"Target: **{tgt}**. Model prediction: **{label}**\n\n"
        if fooled:
            state["fooled"] += 1
            ret += " You fooled the model! Well done!"
        else:
            ret += " You did not fool the model! Too bad, try again!"

        # Store the actual example (not only the rendered message) so the
        # submitted HIT data contains what the worker wrote.
        state["data"].append(
            {
                "text": txt,
                "target": tgt,
                "prediction": label,
                "score": pred["score"],
                "fooled": fooled,
                "message": ret,
            }
        )
        state["cnt"] += 1

        done = state["cnt"] >= total_cnt
        return (
            pred_confidences,
            ret,
            state,
            gr.update(visible=not done),
            gr.update(visible=done),
            _progress_md(state),
        )

    text_input = gr.Textbox(placeholder="Enter model-fooling statement", show_label=False)
    labels = ["Positive", "Negative"]
    # Randomize the radio order (presumably to avoid position bias).
    random.shuffle(labels)
    label_input = gr.Radio(choices=labels, label="Target (correct) label")
    label_output = gr.Label()
    text_output = gr.Markdown()
    with gr.Column() as example_submit:
        submit_ex_button = gr.Button("Submit")
    with gr.Column(visible=False) as final_submit:
        submit_hit_button = gr.Button("Submit HIT")

    # Origins MTurk may pass as turkSubmitTo. Anything else is rejected so the
    # server never posts worker data to a URL chosen via the (untrusted) query string.
    _MTURK_ORIGINS = ("https://www.mturk.com", "https://workersandbox.mturk.com")
    _DEFAULT_MTURK_ORIGIN = "https://workersandbox.mturk.com"

    def _submit(state, dummy):
        """Post the collected state to MTurk's externalSubmit endpoint.

        Args:
            state: Per-session state dict; its "assignmentId" is filled in here.
            dummy: Page query string (e.g. "?assignmentId=...&turkSubmitTo=..."),
                injected by the button's _js hook.

        Returns:
            (status message, state, dummy), matching the click outputs.

        Raises:
            ValueError: if the query string has no assignmentId or names an
                unexpected turkSubmitTo origin.
        """
        import json

        query = parse_qs(dummy.lstrip("?"))
        # Explicit check instead of assert, which is stripped under `python -O`.
        if "assignmentId" not in query:
            raise ValueError("No assignment ID provided, unable to submit")
        # parse_qs maps each key to a list of values; keep the first one.
        state["assignmentId"] = query["assignmentId"][0]

        origin = query.get("turkSubmitTo", [_DEFAULT_MTURK_ORIGIN])[0].rstrip("/")
        if origin not in _MTURK_ORIGINS:
            raise ValueError(f"Unexpected turkSubmitTo origin: {origin!r}")
        url = origin + "/mturk/externalSubmit"

        # Send scalars as-is and containers as JSON: requests' form encoding would
        # otherwise expand lists into repeated keys and reduce dicts to their keys.
        payload = {
            key: json.dumps(value) if isinstance(value, (list, dict)) else value
            for key, value in state.items()
        }
        # NOTE(review): MTurk normally expects externalSubmit to be posted from the
        # worker's browser; confirm a server-side POST is accepted.
        response = requests.post(url, data=payload, timeout=30)

        return str(response) + " With assignmentId " + state["assignmentId"], state, dummy

    submit_ex_button.click(
        _predict,
        inputs=[text_input, label_input, state],
        outputs=[label_output, text_output, state, example_submit, final_submit, state_display],
    )

    response_output = gr.Markdown()
    submit_hit_button.click(
        _submit,
        inputs=[state, dummy],
        outputs=[response_output, state, dummy],
        # Runs in the browser before _submit: passes state through and replaces
        # dummy's value with the page's query string.
        _js="function(state, dummy) { return [state, window.location.search]; }",
    )
|
|
|
|
|
# Start the server only when run as a script, so importing this module
# (e.g. to access `demo`) does not block on launch().
if __name__ == "__main__":
    demo.launch()