maximuspowers commited on
Commit
d77f6b0
·
verified ·
1 Parent(s): d00edf9

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +311 -0
app.py ADDED
@@ -0,0 +1,311 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import plotly.graph_objects as go
3
+ import json
4
+ import gradio as gr
5
+ from nltk.corpus import words
6
+ import nltk
7
+
8
+
9
+ # load files w embeddings, attention scores, and tokens
10
+ vocab_embeddings = np.load('vocab_embeddings.npy')
11
+ with open('vocab_attention_scores.json', 'r') as f:
12
+ vocab_attention_scores = json.load(f)
13
+ with open('vocab_tokens.json', 'r') as f:
14
+ vocab_tokens = json.load(f)
15
+
16
+ # attention scores to numpy arrs
17
+ b_gen_attention = np.array([score['B-GEN'] for score in vocab_attention_scores])
18
+ i_gen_attention = np.array([score['I-GEN'] for score in vocab_attention_scores])
19
+ b_unfair_attention = np.array([score['B-UNFAIR'] for score in vocab_attention_scores])
20
+ i_unfair_attention = np.array([score['I-UNFAIR'] for score in vocab_attention_scores])
21
+ b_stereo_attention = np.array([score['B-STEREO'] for score in vocab_attention_scores])
22
+ i_stereo_attention = np.array([score['I-STEREO'] for score in vocab_attention_scores])
23
+ o_attention = np.array([score['O'] for score in vocab_attention_scores]) # Use actual O scores
24
+
25
+ # remove non-dict english words, but keep subwords ##
26
+ nltk.download('words')
27
+ english_words = set(words.words())
28
+
29
+ filtered_indices = [i for i, token in enumerate(vocab_tokens) if token in english_words or token.startswith("##")]
30
+ filtered_tokens = [vocab_tokens[i] for i in filtered_indices]
31
+
32
+ b_gen_attention_filtered = b_gen_attention[filtered_indices]
33
+ i_gen_attention_filtered = i_gen_attention[filtered_indices]
34
+ b_unfair_attention_filtered = b_unfair_attention[filtered_indices]
35
+ i_unfair_attention_filtered = i_unfair_attention[filtered_indices]
36
+ b_stereo_attention_filtered = b_stereo_attention[filtered_indices]
37
+ i_stereo_attention_filtered = i_stereo_attention[filtered_indices]
38
+ o_attention_filtered = o_attention[filtered_indices]
39
+
40
+ # plot top 500 O tokens for comparison
41
+ top_500_o_indices = np.argsort(o_attention_filtered)[-500:]
42
+ top_500_o_tokens = [filtered_tokens[i] for i in top_500_o_indices]
43
+ o_attention_filtered_top_500 = o_attention_filtered[top_500_o_indices]
44
+
45
+ # tool tip for tokens
46
+ def create_hover_text(tokens, b_gen, i_gen, b_unfair, i_unfair, b_stereo, i_stereo, o_val):
47
+ hover_text = []
48
+ for i in range(len(tokens)):
49
+ hover_text.append(
50
+ f"Token: {tokens[i]}<br>"
51
+ f"B-GEN: {b_gen[i]:.3f}, I-GEN: {i_gen[i]:.3f}<br>"
52
+ f"B-UNFAIR: {b_unfair[i]:.3f}, I-UNFAIR: {i_unfair[i]:.3f}<br>"
53
+ f"B-STEREO: {b_stereo[i]:.3f}, I-STEREO: {i_stereo[i]:.3f}<br>"
54
+ f"O: {o_val[i]:.3f}"
55
+ )
56
+ return hover_text
57
+
58
+ # ploting top 100 tokens for each entity
59
+ def select_top_100(*data_arrays):
60
+ indices_list = []
61
+ for data in data_arrays:
62
+ if data is not None:
63
+ top_indices = np.argsort(data)[-100:]
64
+ indices_list.append(top_indices)
65
+
66
+ combined_indices = np.unique(np.concatenate(indices_list))
67
+
68
+ # filter based on combined indices
69
+ filtered_data = [data[combined_indices] if data is not None else None for data in data_arrays]
70
+ tokens_filtered = [filtered_tokens[i] for i in combined_indices]
71
+
72
+ return (*filtered_data, tokens_filtered)
73
+
74
+ # plots for 1 2 and 3 D
75
+ def create_plot(selected_dimensions):
76
+ # plot data
77
+ attention_map = {
78
+ 'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
79
+ 'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
80
+ 'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
81
+ }
82
+
83
+ # init x, y, z so they can be moved around
84
+ x_data, y_data, z_data = None, None, None
85
+
86
+ # use selected dimentsions to order dimensions
87
+ if len(selected_dimensions) > 0:
88
+ x_data = attention_map[selected_dimensions[0]]
89
+ if len(selected_dimensions) > 1:
90
+ y_data = attention_map[selected_dimensions[1]]
91
+ if len(selected_dimensions) > 2:
92
+ z_data = attention_map[selected_dimensions[2]]
93
+
94
+ # select top 100 dps for each selected dimension
95
+ x_data, y_data, z_data, tokens_filtered = select_top_100(x_data, y_data, z_data)
96
+
97
+ # filter the O tokens using the same dimensions
98
+ o_x = attention_map[selected_dimensions[0]][top_500_o_indices]
99
+ if len(selected_dimensions) > 1:
100
+ o_y = attention_map[selected_dimensions[1]][top_500_o_indices]
101
+ else:
102
+ o_y = np.zeros_like(o_x)
103
+ if len(selected_dimensions) > 2:
104
+ o_z = attention_map[selected_dimensions[2]][top_500_o_indices]
105
+ else:
106
+ o_z = np.zeros_like(o_x)
107
+
108
+ # hover text for GUS tokens
109
+ classified_hover_text = create_hover_text(
110
+ tokens_filtered,
111
+ b_gen_attention_filtered, i_gen_attention_filtered,
112
+ b_unfair_attention_filtered, i_unfair_attention_filtered,
113
+ b_stereo_attention_filtered, i_stereo_attention_filtered,
114
+ o_attention_filtered
115
+ )
116
+
117
+ # hover text for O tokens
118
+ o_hover_text = create_hover_text(
119
+ top_500_o_tokens,
120
+ b_gen_attention_filtered[top_500_o_indices], i_gen_attention_filtered[top_500_o_indices],
121
+ b_unfair_attention_filtered[top_500_o_indices], i_unfair_attention_filtered[top_500_o_indices],
122
+ b_stereo_attention_filtered[top_500_o_indices], i_stereo_attention_filtered[top_500_o_indices],
123
+ o_attention_filtered_top_500
124
+ )
125
+
126
+
127
+ # plot
128
+ fig = go.Figure()
129
+
130
+ if x_data is not None and y_data is not None and z_data is not None:
131
+ # 3d scatter plot
132
+ fig.add_trace(go.Scatter3d(
133
+ x=x_data,
134
+ y=y_data,
135
+ z=z_data,
136
+ mode='markers',
137
+ marker=dict(
138
+ size=6,
139
+ color=x_data, # color based on the x-axis data
140
+ colorscale='Viridis',
141
+ opacity=0.85,
142
+ ),
143
+ text=classified_hover_text,
144
+ hoverinfo='text',
145
+ name='Classified Tokens'
146
+ ))
147
+ # add top 500 O tags to the plot too
148
+ fig.add_trace(go.Scatter3d(
149
+ x=o_x,
150
+ y=o_y,
151
+ z=o_z,
152
+ mode='markers',
153
+ marker=dict(
154
+ size=6,
155
+ color='grey',
156
+ opacity=0.5,
157
+ ),
158
+ text=o_hover_text,
159
+ hoverinfo='text',
160
+ name='O Tokens'
161
+ ))
162
+ elif x_data is not None and y_data is not None:
163
+ # 2d scatter plot
164
+ fig.add_trace(go.Scatter(
165
+ x=x_data,
166
+ y=y_data,
167
+ mode='markers',
168
+ marker=dict(
169
+ size=6,
170
+ color=x_data, # color based on the x-axis data
171
+ colorscale='Viridis',
172
+ opacity=0.85,
173
+ ),
174
+ text=classified_hover_text,
175
+ hoverinfo='text',
176
+ name='Classified Tokens'
177
+ ))
178
+ # add top 500 O tags to the plot too
179
+ fig.add_trace(go.Scatter(
180
+ x=o_x,
181
+ y=o_y,
182
+ mode='markers',
183
+ marker=dict(
184
+ size=6,
185
+ color='grey',
186
+ opacity=0.5,
187
+ ),
188
+ text=o_hover_text,
189
+ hoverinfo='text',
190
+ name='O Tokens'
191
+ ))
192
+ elif x_data is not None:
193
+ # 1D scatter plot
194
+ fig.add_trace(go.Scatter(
195
+ x=x_data,
196
+ y=np.zeros_like(x_data),
197
+ mode='markers',
198
+ marker=dict(
199
+ size=6,
200
+ color=x_data,
201
+ colorscale='Viridis',
202
+ opacity=0.85,
203
+ ),
204
+ text=classified_hover_text,
205
+ hoverinfo='text',
206
+ name='GUS Tokens'
207
+ ))
208
+ fig.add_trace(go.Scatter(
209
+ x=o_x,
210
+ y=np.zeros_like(o_x),
211
+ mode='markers',
212
+ marker=dict(
213
+ size=6,
214
+ color='grey',
215
+ opacity=0.5,
216
+ ),
217
+ text=o_hover_text,
218
+ hoverinfo='text',
219
+ name='O Tokens'
220
+ ))
221
+
222
+ # update layout dynamically
223
+ if x_data is not None and y_data is not None and z_data is not None:
224
+ # 3D
225
+ fig.update_layout(
226
+ title="GUS-Net Entity Attentions Visualization",
227
+ scene=dict(
228
+ xaxis=dict(title=f"{selected_dimensions[0]} Attention"),
229
+ yaxis=dict(title=f"{selected_dimensions[1]} Attention"),
230
+ zaxis=dict(title=f"{selected_dimensions[2]} Attention"),
231
+ ),
232
+ margin=dict(l=0, r=0, b=0, t=40),
233
+ )
234
+ elif x_data is not None and y_data is not None:
235
+ # 2D
236
+ fig.update_layout(
237
+ title="GUS-Net Entity Attentions Visualization",
238
+ xaxis_title=f"{selected_dimensions[0]} Attention",
239
+ yaxis_title=f"{selected_dimensions[1]} Attention",
240
+ margin=dict(l=0, r=0, b=0, t=40),
241
+ )
242
+ elif x_data is not None:
243
+ # 1D
244
+ fig.update_layout(
245
+ title="GUS-Net Entity Attentions Visualization",
246
+ xaxis_title=f"{selected_dimensions[0]} Attention",
247
+ margin=dict(l=0, r=0, b=0, t=40),
248
+ )
249
+
250
+ return fig
251
+
252
+ def get_top_tokens_for_entities(selected_dimensions):
253
+ entity_map = {
254
+ 'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
255
+ 'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
256
+ 'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
257
+ }
258
+
259
+ top_tokens_info = {}
260
+ for dimension in selected_dimensions:
261
+ if dimension in entity_map:
262
+ attention_scores = entity_map[dimension]
263
+ top_indices = np.argsort(attention_scores)[-10:] # top 10 tokens
264
+ top_tokens = [filtered_tokens[i] for i in top_indices]
265
+ top_scores = attention_scores[top_indices]
266
+ top_tokens_info[dimension] = list(zip(top_tokens, top_scores))
267
+
268
+ return top_tokens_info
269
+
270
+ def update_gradio(selected_dimensions):
271
+ fig = create_plot(selected_dimensions)
272
+
273
+ top_tokens_info = get_top_tokens_for_entities(selected_dimensions)
274
+
275
+ formatted_top_tokens = ""
276
+ for entity, tokens_info in top_tokens_info.items():
277
+ formatted_top_tokens += f"\nTop tokens for {entity}:\n"
278
+ for token, score in tokens_info:
279
+ formatted_top_tokens += f"Token: {token}, Attention Score: {score:.3f}\n"
280
+
281
+ return fig, formatted_top_tokens
282
+
283
+
284
+ def render_gradio_interface():
285
+ with gr.Blocks() as interface:
286
+ with gr.Column():
287
+ dimensions_input = gr.CheckboxGroup(
288
+ choices=["Generalization", "Unfairness", "Stereotype"],
289
+ label="Select Dimensions to Plot",
290
+ value=["Generalization", "Unfairness", "Stereotype"] # defaults to 3D
291
+ )
292
+
293
+ plot_output = gr.Plot(label="Token Attention Visualization")
294
+ top_tokens_output = gr.Textbox(label="Top Tokens for Each Entity Class", lines=10)
295
+
296
+ dimensions_input.change(
297
+ fn=update_gradio,
298
+ inputs=[dimensions_input],
299
+ outputs=[plot_output, top_tokens_output]
300
+ )
301
+
302
+ interface.load(
303
+ fn=lambda: update_gradio(["Generalization", "Unfairness", "Stereotype"]),
304
+ inputs=None,
305
+ outputs=[plot_output, top_tokens_output]
306
+ )
307
+
308
+ return interface
309
+
310
+ interface = render_gradio_interface()
311
+ interface.launch()