Anon Anon commited on
Commit
8eee1b1
·
1 Parent(s): 43d49fa

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +280 -0
app.py ADDED
@@ -0,0 +1,280 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # %%
2
+ import gradio as gr
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import pandas as pd
6
+ import random
7
+ from matplotlib.ticker import MaxNLocator
8
+ from transformers import pipeline
9
+ from winogender_sentences import get_sentences
10
+
11
+ OWN_MODEL_NAME = 'add-a-model'
12
+ PICK_YOUR_OWN_LABEL = 'pick-your-own'
13
+
14
+ MODEL_NAME_DICT = {
15
+ "roberta-large": "RoBERTa-large",
16
+ "bert-large-uncased": "BERT-large",
17
+ "roberta-base": "RoBERTa-base",
18
+ "bert-base-uncased": "BERT-base",
19
+ OWN_MODEL_NAME: "Your model's"
20
+ }
21
+ MODEL_NAMES = list(MODEL_NAME_DICT.keys())
22
+
23
+
24
+ DECIMAL_PLACES = 1
25
+ EPS = 1e-5 # to avoid /0 errors
26
+ NUM_PTS_TO_AVERAGE = 2
27
+
28
+ # Example date conts
29
+ DATE_SPLIT_KEY = "DATE"
30
+ START_YEAR = 1901
31
+ STOP_YEAR = 2016
32
+ NUM_PTS = 30
33
+ DATES = np.linspace(START_YEAR, STOP_YEAR, NUM_PTS).astype(int).tolist()
34
+ DATES = [f'{d}' for d in DATES]
35
+
36
+ GENDERED_LIST = [
37
+ ['he', 'she'],
38
+ ['him', 'her'],
39
+ ['his', 'hers'],
40
+ ["himself", "herself"],
41
+ ['male', 'female'],
42
+ # ['man', 'woman'] Explicitly added in winogender extended sentences
43
+ ['men', 'women'],
44
+ ["husband", "wife"],
45
+ ['father', 'mother'],
46
+ ['boyfriend', 'girlfriend'],
47
+ ['brother', 'sister'],
48
+ ["actor", "actress"],
49
+ ]
50
+
51
+
52
+ # %%
53
+ # Fire up the models
54
+ models = {m : pipeline("fill-mask", model=m) for m in MODEL_NAMES if m != OWN_MODEL_NAME}
55
+
56
+ # %%
57
+ # Get the winogender sentences
58
+ winogender_sentences = get_sentences()
59
+ occs = sorted(list({sentence_id.split('_')[0]
60
+ for sentence_id in winogender_sentences}))
61
+
62
+ # %%
63
+ def get_gendered_token_ids():
64
+ male_gendered_tokens = [list[0] for list in GENDERED_LIST]
65
+ female_gendered_tokens = [list[1] for list in GENDERED_LIST]
66
+
67
+ return male_gendered_tokens, female_gendered_tokens
68
+
69
+
70
+ def get_winogender_texts(occ):
71
+ return [winogender_sentences[id] for id in winogender_sentences.keys() if id.split('_')[0] == occ]
72
+
73
+
74
+ def display_input_texts(occ, alt_text):
75
+ if occ == PICK_YOUR_OWN_LABEL:
76
+ texts = alt_text.split('\n')
77
+ else:
78
+ texts = get_winogender_texts(occ)
79
+
80
+ display_texts = [
81
+ f"{i+1}) {text}" for (i, text) in enumerate(texts)]
82
+ return "\n".join(display_texts), texts
83
+
84
+
85
+ def get_avg_prob_from_pipeline_outputs(pipeline_preds, gendered_tokens, num_preds):
86
+ pronoun_preds = [sum([
87
+ pronoun["score"] if pronoun["token_str"].strip(
88
+ ).lower() in gendered_tokens else 0.0
89
+ for pronoun in top_preds])
90
+ for top_preds in pipeline_preds
91
+ ]
92
+ return round(sum(pronoun_preds) / (EPS + num_preds) * 100, DECIMAL_PLACES)
93
+
94
+
95
+ def is_top_pred_gendered(pipeline_preds, gendered_tokens):
96
+ return pipeline_preds[0][0]['token_str'].strip().lower() in gendered_tokens
97
+
98
+ # %%
99
+
100
+
101
+ def get_figure(df, model_name, occ):
102
+ xs = df[df.columns[0]]
103
+ ys = df[df.columns[1]]
104
+
105
+ fig, ax = plt.subplots()
106
+ ax.bar(xs, ys)
107
+ ax.axis('tight')
108
+ ax.set_xlabel("Sentence number")
109
+ ax.set_ylabel("Uncertainty metric")
110
+ ax.set_title(f"{MODEL_NAME_DICT[model_name]} gender pronoun uncertainty in '{occ}' sentences")
111
+ return fig
112
+
113
+
114
+ # %%
115
+ def predict_gender_pronouns(
116
+ model_name,
117
+ own_model_name,
118
+ texts,
119
+ occ,
120
+ ):
121
+ """Run inference on input_text for selected model type, returning uncertainty results.
122
+ """
123
+
124
+ # TODO: make these selectable by user
125
+ indie_vars = ', '.join(DATES)
126
+ num_ave = NUM_PTS_TO_AVERAGE
127
+
128
+ # For debugging
129
+ print('input_texts', texts)
130
+
131
+ if model_name is None or model_name == '':
132
+ model_name = MODEL_NAMES[0]
133
+ model = models[model_name]
134
+ elif model_name == OWN_MODEL_NAME:
135
+ model = pipeline("fill-mask", model=own_model_name)
136
+ else:
137
+ model = models[model_name]
138
+
139
+ mask_token = model.tokenizer.mask_token
140
+
141
+ indie_vars_list = indie_vars.split(',')
142
+
143
+ male_gendered_tokens, female_gendered_tokens = get_gendered_token_ids()
144
+
145
+ masked_texts = [text.replace('MASK', mask_token) for text in texts]
146
+
147
+ all_uncertainty_f = {}
148
+ not_top_gendered = set()
149
+
150
+ for i, text in enumerate(masked_texts):
151
+ female_pronoun_preds = []
152
+ male_pronoun_preds = []
153
+ top_pred_gendered = True # Assume true unless told otherwise
154
+ print(f"{i+1}) {text}")
155
+ for indie_var in indie_vars_list[:num_ave] + indie_vars_list[-num_ave:]:
156
+
157
+ target_text = f"In {indie_var}: {text}"
158
+
159
+ pipeline_preds = model(target_text)
160
+ # Quick hack as realized return type based on how many MASKs in text.
161
+ if type(pipeline_preds[0]) is not list:
162
+ pipeline_preds = [pipeline_preds]
163
+
164
+ # If top-pred not gendered, record as such
165
+ if not is_top_pred_gendered(pipeline_preds, female_gendered_tokens + male_gendered_tokens):
166
+ top_pred_gendered = False
167
+
168
+ num_preds = 1 # By design
169
+ female_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
170
+ pipeline_preds,
171
+ female_gendered_tokens,
172
+ num_preds
173
+ ))
174
+ male_pronoun_preds.append(get_avg_prob_from_pipeline_outputs(
175
+ pipeline_preds,
176
+ male_gendered_tokens,
177
+ num_preds
178
+ ))
179
+
180
+ # Normalizing by all gendered predictions
181
+ total_gendered_probs = np.add(
182
+ female_pronoun_preds, male_pronoun_preds)
183
+
184
+ norm_female_pronoun_preds = np.around(
185
+ np.divide(female_pronoun_preds, total_gendered_probs+EPS)*100,
186
+ decimals=DECIMAL_PLACES
187
+ )
188
+ sent_idx = f"{i+1}" if top_pred_gendered else f"{i+1}*"
189
+ all_uncertainty_f[sent_idx] = round(abs((sum(norm_female_pronoun_preds[-num_ave:]) - sum(norm_female_pronoun_preds[:num_ave]))
190
+ / num_ave), DECIMAL_PLACES)
191
+
192
+ uncertain_df = pd.DataFrame.from_dict(
193
+ all_uncertainty_f, orient='index', columns=['Uncertainty metric'])
194
+
195
+ uncertain_df = uncertain_df.reset_index().rename(
196
+ columns={'index': 'Sentence number'})
197
+
198
+ return (
199
+ target_text,
200
+ uncertain_df,
201
+ get_figure(uncertain_df, model_name, occ),
202
+ )
203
+
204
+
205
+ demo = gr.Blocks()
206
+ with demo:
207
+ input_texts = gr.Variable([])
208
+ gr.Markdown("## Are you certain?")
209
+ gr.Markdown(
210
+ "#### LLMs are pretty good at reporting their uncertainty. We just need to ask the right way.")
211
+ gr.Markdown("Using our uncertainty metric informed by applying causal inference techniques in \
212
+ [Our ICLR paper under review](https://openreview.net/pdf?id=25VgHaPz0l4), \
213
+ we are able to identify likely spurious correlations and exploit them in \
214
+ the scenario of gender underspecified tasks. (Note that introspecting softmax probabilities alone is insufficient, as in the sentences \
215
+ below, LLMs may report a softmax prob of ~0.9 despite the task being underspecified.)")
216
+
217
+ gr.Markdown("We extend the [Winogender Schemas](https://github.com/rudinger/winogender-schemas) evaluation set to produce\
218
+ eight syntactically similar sentences. However semantically, \
219
+ only two of the sentences are gender-specified while the rest remain gender-underspecified")
220
+ gr.Markdown("If a model can reliably tell us when it is uncertain about its predictions, one can replace only those uncertain predictions with\
221
+ an appropriate heuristic.")
222
+
223
+ with gr.Row():
224
+ model_name = gr.Radio(
225
+ MODEL_NAMES,
226
+ type="value",
227
+ label="Pick a preloaded BERT-like model for uncertainty evaluation (note: BERT-base performance least consistent)...",
228
+ )
229
+ own_model_name = gr.Textbox(
230
+ label=f"...Or, if you selected an '{OWN_MODEL_NAME}' model, put any Hugging Face pipeline model name \
231
+ (that supports the `fill-mask` task (see list at https://huggingface.co/models?pipeline_tag=fill-mask).",
232
+ )
233
+
234
+ with gr.Row():
235
+ occ_box = gr.Radio(
236
+ occs+[PICK_YOUR_OWN_LABEL], label=f"Pick an Occupation type from the Winogender Schemas evaluation set, or select '{PICK_YOUR_OWN_LABEL}'\
237
+ (it need not be about an occupation).")
238
+
239
+ with gr.Row():
240
+ alt_input_texts = gr.Textbox(
241
+ lines=2,
242
+ label=f"...If you selected '{PICK_YOUR_OWN_LABEL}' above, add your own texts new-line delimited sentences here. Be sure\
243
+ to include a single MASK-ed out pronoun. \
244
+ If unsure on the required format, click an occupation above instead, to see some example input texts for this round.",
245
+ )
246
+
247
+ with gr.Row():
248
+ get_text_btn = gr.Button("Load input texts")
249
+
250
+ get_text_btn.click(
251
+ fn=display_input_texts,
252
+ inputs=[occ_box, alt_input_texts],
253
+ outputs=[gr.Textbox(
254
+ label='Numbered sentences for evaluation. Number below corresponds to number in x-axis of plot.'), input_texts],
255
+
256
+ )
257
+
258
+ with gr.Row():
259
+ uncertain_btn = gr.Button("Get uncertainty results!")
260
+ gr.Markdown(
261
+ "If there is an * by a sentence number, then at least one top prediction for that sentence was non-gendered.")
262
+
263
+ with gr.Row():
264
+ female_fig = gr.Plot(type="auto")
265
+ with gr.Row():
266
+ female_df = gr.Dataframe()
267
+ with gr.Row():
268
+ display_text = gr.Textbox(
269
+ type="auto", label="Sample of text fed to model")
270
+
271
+ uncertain_btn.click(
272
+ fn=predict_gender_pronouns,
273
+ inputs=[model_name, own_model_name, input_texts, occ_box],
274
+ # inputs=date_example,
275
+ outputs=[display_text, female_df, female_fig]
276
+ )
277
+
278
+ demo.launch(debug=True)
279
+
280
+ # %%