Spaces:
Runtime error
Runtime error
Anon Anon
commited on
Commit
·
8eee1b1
1
Parent(s):
43d49fa
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# %%
|
2 |
+
import gradio as gr
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
import numpy as np
|
5 |
+
import pandas as pd
|
6 |
+
import random
|
7 |
+
from matplotlib.ticker import MaxNLocator
|
8 |
+
from transformers import pipeline
|
9 |
+
from winogender_sentences import get_sentences
|
10 |
+
|
11 |
+
OWN_MODEL_NAME = 'add-a-model'
|
12 |
+
PICK_YOUR_OWN_LABEL = 'pick-your-own'
|
13 |
+
|
14 |
+
MODEL_NAME_DICT = {
|
15 |
+
"roberta-large": "RoBERTa-large",
|
16 |
+
"bert-large-uncased": "BERT-large",
|
17 |
+
"roberta-base": "RoBERTa-base",
|
18 |
+
"bert-base-uncased": "BERT-base",
|
19 |
+
OWN_MODEL_NAME: "Your model's"
|
20 |
+
}
|
21 |
+
MODEL_NAMES = list(MODEL_NAME_DICT.keys())
|
22 |
+
|
23 |
+
|
24 |
+
DECIMAL_PLACES = 1
|
25 |
+
EPS = 1e-5 # to avoid /0 errors
|
26 |
+
NUM_PTS_TO_AVERAGE = 2
|
27 |
+
|
28 |
+
# Example date conts
|
29 |
+
DATE_SPLIT_KEY = "DATE"
|
30 |
+
START_YEAR = 1901
|
31 |
+
STOP_YEAR = 2016
|
32 |
+
NUM_PTS = 30
|
33 |
+
DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
|
34 |
+
DATES = [f'{d}' for d in DATES]
|
35 |
+
|
36 |
+
GENDERED_LIST = [
|
37 |
+
['he', 'she'],
|
38 |
+
['him', 'her'],
|
39 |
+
['his', 'hers'],
|
40 |
+
["himself", "herself"],
|
41 |
+
['male', 'female'],
|
42 |
+
# ['man', 'woman'] Explicitly added in winogender extended sentences
|
43 |
+
['men', 'women'],
|
44 |
+
["husband", "wife"],
|
45 |
+
['father', 'mother'],
|
46 |
+
['boyfriend', 'girlfriend'],
|
47 |
+
['brother', 'sister'],
|
48 |
+
["actor", "actress"],
|
49 |
+
]
|
50 |
+
|
51 |
+
|
52 |
+
# %%
|
53 |
+
# Fire up the models
|
54 |
+
models = {m : pipeline("fill-mask", model=m) for m in MODEL_NAMES if m != OWN_MODEL_NAME}
|
55 |
+
|
56 |
+
# %%
|
57 |
+
# Get the winogender sentences
|
58 |
+
winogender_sentences = get_sentences()
|
59 |
+
occs = sorted(list({sentence_id.split('_')[0]
|
60 |
+
for sentence_id in winogender_sentences}))
|
61 |
+
|
62 |
+
# %%
|
63 |
+
def get_gendered_token_ids():
|
64 |
+
male_gendered_tokens = [list[0] for list in GENDERED_LIST]
|
65 |
+
female_gendered_tokens = [list[1] for list in GENDERED_LIST]
|
66 |
+
|
67 |
+
return male_gendered_tokens, female_gendered_tokens
|
68 |
+
|
69 |
+
|
70 |
+
def get_winogender_texts(occ):
|
71 |
+
return [winogender_sentences[id] for id in winogender_sentences.keys() if id.split('_')[0] == occ]
|
72 |
+
|
73 |
+
|
74 |
+
def display_input_texts(occ, alt_text):
|
75 |
+
if occ == PICK_YOUR_OWN_LABEL:
|
76 |
+
texts = alt_text.split('\n')
|
77 |
+
else:
|
78 |
+
texts = get_winogender_texts(occ)
|
79 |
+
|
80 |
+
display_texts = [
|
81 |
+
f"{i+1}) {text}" for (i, text) in enumerate(texts)]
|
82 |
+
return "\n".join(display_texts), texts
|
83 |
+
|
84 |
+
|
85 |
+
def get_avg_prob_from_pipeline_outputs(pipeline_preds, gendered_tokens, num_preds):
|
86 |
+
pronoun_preds = [sum([
|
87 |
+
pronoun["score"] if pronoun["token_str"].strip(
|
88 |
+
).lower() in gendered_tokens else 0.0
|
89 |
+
for pronoun in top_preds])
|
90 |
+
for top_preds in pipeline_preds
|
91 |
+
]
|
92 |
+
return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
|
93 |
+
|
94 |
+
|
95 |
+
def is_top_pred_gendered(pipeline_preds, gendered_tokens):
|
96 |
+
return pipeline_preds[0][0]['token_str'].strip().lower() in gendered_tokens
|
97 |
+
|
98 |
+
# %%
|
99 |
+
|
100 |
+
|
101 |
+
def get_figure(df, model_name, occ):
|
102 |
+
xs = df[df.columns[0]]
|
103 |
+
ys = df[df.columns[1]]
|
104 |
+
|
105 |
+
fig, ax = plt.subplots()
|
106 |
+
ax.bar(xs, ys)
|
107 |
+
ax.axis('tight')
|
108 |
+
ax.set_xlabel("Sentence number")
|
109 |
+
ax.set_ylabel("Uncertainty metric")
|
110 |
+
ax.set_title(f"{MODEL_NAME_DICT[model_name]} gender pronoun uncertainty in '{occ}' sentences")
|
111 |
+
return fig
|
112 |
+
|
113 |
+
|
114 |
+
# %%
|
115 |
+
def predict_gender_pronouns(
|
116 |
+
model_name,
|
117 |
+
own_model_name,
|
118 |
+
texts,
|
119 |
+
occ,
|
120 |
+
):
|
121 |
+
"""Run inference on input_text for selected model type, returning uncertainty results.
|
122 |
+
"""
|
123 |
+
|
124 |
+
# TODO: make these selectable by user
|
125 |
+
indie_vars = ', '.join(DATES)
|
126 |
+
num_ave = NUM_PTS_TO_AVERAGE
|
127 |
+
|
128 |
+
# For debugging
|
129 |
+
print('input_texts', texts)
|
130 |
+
|
131 |
+
if model_name is None or model_name == '':
|
132 |
+
model_name = MODEL_NAMES[0]
|
133 |
+
model = models[model_name]
|
134 |
+
elif model_name == OWN_MODEL_NAME:
|
135 |
+
model = pipeline("fill-mask", model=own_model_name)
|
136 |
+
else:
|
137 |
+
model = models[model_name]
|
138 |
+
|
139 |
+
mask_token = model.tokenizer.mask_token
|
140 |
+
|
141 |
+
indie_vars_list = indie_vars.split(',')
|
142 |
+
|
143 |
+
male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()
|
144 |
+
|
145 |
+
masked_texts = [text.replace('MASK', mask_token) for text in texts]
|
146 |
+
|
147 |
+
all_uncertainty_f = {}
|
148 |
+
not_top_gendered = set()
|
149 |
+
|
150 |
+
for i, text in enumerate(masked_texts):
|
151 |
+
female_pronoun_preds = []
|
152 |
+
male_pronoun_preds = []
|
153 |
+
top_pred_gendered = True # Assume true unless told otherwise
|
154 |
+
print(f"{i+1}) {text}")
|
155 |
+
for indie_var in indie_vars_list[:num_ave] + indie_vars_list[-num_ave:]:
|
156 |
+
|
157 |
+
target_text = f"In {indie_var}: {text}"
|
158 |
+
|
159 |
+
pipeline_preds = model(target_text)
|
160 |
+
# Quick hack as realized return type based on how many MASKs in text.
|
161 |
+
if type(pipeline_preds[0]) is not list:
|
162 |
+
pipeline_preds = [pipeline_preds]
|
163 |
+
|
164 |
+
# If top-pred not gendered, record as such
|
165 |
+
if not is_top_pred_gendered(pipeline_preds, female_gendered_tokens + male_gendered_tokens):
|
166 |
+
top_pred_gendered = False
|
167 |
+
|
168 |
+
num_preds = 1 # By design
|
169 |
+
female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
|
170 |
+
pipeline_preds,
|
171 |
+
female_gendered_tokens,
|
172 |
+
num_preds
|
173 |
+
))
|
174 |
+
male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
|
175 |
+
pipeline_preds,
|
176 |
+
male_gendered_tokens,
|
177 |
+
num_preds
|
178 |
+
))
|
179 |
+
|
180 |
+
# Normalizing by all gendered predictions
|
181 |
+
total_gendered_probs = np.add(
|
182 |
+
female_pronoun_preds, male_pronoun_preds)
|
183 |
+
|
184 |
+
norm_female_pronoun_preds = np.around(
|
185 |
+
np.divide(female_pronoun_preds, total_gendered_probs+EPS)*100,
|
186 |
+
decimals=DECIMAL_PLACES
|
187 |
+
)
|
188 |
+
sent_idx = f"{i+1}" if top_pred_gendered else f"{i+1}*"
|
189 |
+
all_uncertainty_f[sent_idx] = round(abs((sum(norm_female_pronoun_preds[-num_ave:]) - sum(norm_female_pronoun_preds[:num_ave]))
|
190 |
+
/ num_ave), DECIMAL_PLACES)
|
191 |
+
|
192 |
+
uncertain_df = pd.DataFrame.from_dict(
|
193 |
+
all_uncertainty_f, orient='index', columns=['Uncertainty metric'])
|
194 |
+
|
195 |
+
uncertain_df = uncertain_df.reset_index().rename(
|
196 |
+
columns={'index': 'Sentence number'})
|
197 |
+
|
198 |
+
return (
|
199 |
+
target_text,
|
200 |
+
uncertain_df,
|
201 |
+
get_figure(uncertain_df, model_name, occ),
|
202 |
+
)
|
203 |
+
|
204 |
+
|
205 |
+
demo = gr.Blocks()
|
206 |
+
with demo:
|
207 |
+
input_texts = gr.Variable([])
|
208 |
+
gr.Markdown("## Are you certain?")
|
209 |
+
gr.Markdown(
|
210 |
+
"#### LLMs are pretty good at reporting their uncertainty. We just need to ask the right way.")
|
211 |
+
gr.Markdown("Using our uncertainty metric informed by applying causal inference techniques in \
|
212 |
+
[Our ICLR paper under review](https://openreview.net/pdf?id=25VgHaPz0l4), \
|
213 |
+
we are able to identify likely spurious correlations and exploit them in \
|
214 |
+
the scenario of gender underspecified tasks. (Note that introspecting softmax probabilities alone is insufficient, as in the sentences \
|
215 |
+
below, LLMs may report a softmax prob of ~0.9 despite the task being underspecified.)")
|
216 |
+
|
217 |
+
gr.Markdown("We extend the [Winogender Schemas](https://github.com/rudinger/winogender-schemas) evaluation set to produce\
|
218 |
+
eight syntactically similar sentences. However semantically, \
|
219 |
+
only two of the sentences are gender-specified while the rest remain gender-underspecified")
|
220 |
+
gr.Markdown("If a model can reliably tell us when it is uncertain about its predictions, one can replace only those uncertain predictions with\
|
221 |
+
an appropriate heuristic.")
|
222 |
+
|
223 |
+
with gr.Row():
|
224 |
+
model_name = gr.Radio(
|
225 |
+
MODEL_NAMES,
|
226 |
+
type="value",
|
227 |
+
label="Pick a preloaded BERT-like model for uncertainty evaluation (note: BERT-base performance least consistent)...",
|
228 |
+
)
|
229 |
+
own_model_name = gr.Textbox(
|
230 |
+
label=f"...Or, if you selected an '{OWN_MODEL_NAME}' model, put any Hugging Face pipeline model name \
|
231 |
+
(that supports the `fill-mask` task (see list at https://huggingface.co/models?pipeline_tag=fill-mask).",
|
232 |
+
)
|
233 |
+
|
234 |
+
with gr.Row():
|
235 |
+
occ_box = gr.Radio(
|
236 |
+
occs+[PICK_YOUR_OWN_LABEL], label=f"Pick an Occupation type from the Winogender Schemas evaluation set, or select '{PICK_YOUR_OWN_LABEL}'\
|
237 |
+
(it need not be about an occupation).")
|
238 |
+
|
239 |
+
with gr.Row():
|
240 |
+
alt_input_texts = gr.Textbox(
|
241 |
+
lines=2,
|
242 |
+
label=f"...If you selected '{PICK_YOUR_OWN_LABEL}' above, add your own texts new-line delimited sentences here. Be sure\
|
243 |
+
to include a single MASK-ed out pronoun. \
|
244 |
+
If unsure on the required format, click an occupation above instead, to see some example input texts for this round.",
|
245 |
+
)
|
246 |
+
|
247 |
+
with gr.Row():
|
248 |
+
get_text_btn = gr.Button("Load input texts")
|
249 |
+
|
250 |
+
get_text_btn.click(
|
251 |
+
fn=display_input_texts,
|
252 |
+
inputs=[occ_box, alt_input_texts],
|
253 |
+
outputs=[gr.Textbox(
|
254 |
+
label='Numbered sentences for evaluation. Number below corresponds to number in x-axis of plot.'), input_texts],
|
255 |
+
|
256 |
+
)
|
257 |
+
|
258 |
+
with gr.Row():
|
259 |
+
uncertain_btn = gr.Button("Get uncertainty results!")
|
260 |
+
gr.Markdown(
|
261 |
+
"If there is an * by a sentence number, then at least one top prediction for that sentence was non-gendered.")
|
262 |
+
|
263 |
+
with gr.Row():
|
264 |
+
female_fig = gr.Plot(type="auto")
|
265 |
+
with gr.Row():
|
266 |
+
female_df = gr.Dataframe()
|
267 |
+
with gr.Row():
|
268 |
+
display_text = gr.Textbox(
|
269 |
+
type="auto", label="Sample of text fed to model")
|
270 |
+
|
271 |
+
uncertain_btn.click(
|
272 |
+
fn=predict_gender_pronouns,
|
273 |
+
inputs=[model_name, own_model_name, input_texts, occ_box],
|
274 |
+
# inputs=date_example,
|
275 |
+
outputs=[display_text, female_df, female_fig]
|
276 |
+
)
|
277 |
+
|
278 |
+
demo.launch(debug=True)
|
279 |
+
|
280 |
+
# %%
|