# GUS-Net entity-attention visualization app (Gradio + Plotly).
import json

import gradio as gr
import nltk
import numpy as np
import plotly.graph_objects as go
from nltk.corpus import words
# ---- data loading: precomputed embeddings, attention scores, and tokens ----
vocab_embeddings = np.load('vocab_embeddings.npy')
with open('vocab_attention_scores.json', 'r') as f:
    vocab_attention_scores = json.load(f)
with open('vocab_tokens.json', 'r') as f:
    vocab_tokens = json.load(f)


def _scores_for(label):
    # Pull one BIO label's score out of every token record as a numpy array.
    return np.array([entry[label] for entry in vocab_attention_scores])


# one attention-score array per BIO entity label
b_gen_attention = _scores_for('B-GEN')
i_gen_attention = _scores_for('I-GEN')
b_unfair_attention = _scores_for('B-UNFAIR')
i_unfair_attention = _scores_for('I-UNFAIR')
b_stereo_attention = _scores_for('B-STEREO')
i_stereo_attention = _scores_for('I-STEREO')
o_attention = _scores_for('O')  # use actual O scores

# keep only dictionary English words, plus WordPiece subwords ("##...")
nltk.download('words')
english_words = set(words.words())
filtered_indices = [
    i for i, token in enumerate(vocab_tokens)
    if token in english_words or token.startswith("##")
]
filtered_tokens = [vocab_tokens[i] for i in filtered_indices]
b_gen_attention_filtered = b_gen_attention[filtered_indices]
i_gen_attention_filtered = i_gen_attention[filtered_indices]
b_unfair_attention_filtered = b_unfair_attention[filtered_indices]
i_unfair_attention_filtered = i_unfair_attention[filtered_indices]
b_stereo_attention_filtered = b_stereo_attention[filtered_indices]
i_stereo_attention_filtered = i_stereo_attention[filtered_indices]
o_attention_filtered = o_attention[filtered_indices]

# top 500 O-tagged tokens, drawn as a grey comparison cloud in every plot
top_500_o_indices = np.argsort(o_attention_filtered)[-500:]
top_500_o_tokens = [filtered_tokens[i] for i in top_500_o_indices]
o_attention_filtered_top_500 = o_attention_filtered[top_500_o_indices]
# tool tip for tokens | |
def create_hover_text(tokens, b_gen, i_gen, b_unfair, i_unfair, b_stereo, i_stereo, o_val):
    """Return one HTML tooltip string per token showing all seven label scores.

    All sequences are expected to be index-aligned with ``tokens``; scores
    are rendered with three decimal places.
    """
    return [
        (
            f"Token: {tok}<br>"
            f"B-GEN: {bg:.3f}, I-GEN: {ig:.3f}<br>"
            f"B-UNFAIR: {bu:.3f}, I-UNFAIR: {iu:.3f}<br>"
            f"B-STEREO: {bs:.3f}, I-STEREO: {i_s:.3f}<br>"
            f"O: {o:.3f}"
        )
        for tok, bg, ig, bu, iu, bs, i_s, o in zip(
            tokens, b_gen, i_gen, b_unfair, i_unfair, b_stereo, i_stereo, o_val
        )
    ]
# ploting top 100 tokens for each entity | |
def select_top_100(*data_arrays):
    """Restrict score arrays to the union of each array's top-100 indices.

    Parameters
    ----------
    *data_arrays : numpy arrays or None
        Index-aligned score arrays (``None`` entries are passed through).

    Returns
    -------
    tuple
        The input arrays restricted to the combined top-100 index set
        (``None`` stays ``None``), followed by the matching entries of the
        module-level ``filtered_tokens`` list.
    """
    indices_list = [
        np.argsort(data)[-100:] for data in data_arrays if data is not None
    ]
    # Guard: all-None input (no dimension selected) previously crashed in
    # np.concatenate([]); return the inputs untouched with an empty token list.
    if not indices_list:
        return (*data_arrays, [])
    combined_indices = np.unique(np.concatenate(indices_list))
    # filter based on combined indices
    filtered_data = [
        data[combined_indices] if data is not None else None
        for data in data_arrays
    ]
    tokens_filtered = [filtered_tokens[i] for i in combined_indices]
    return (*filtered_data, tokens_filtered)
# plots for 1 2 and 3 D | |
def create_plot(selected_dimensions):
    """Build a 1D, 2D, or 3D scatter of token attention scores.

    Parameters
    ----------
    selected_dimensions : list[str]
        Ordered subset of {'Generalization', 'Unfairness', 'Stereotype'}.
        Its length (0-3) picks the plot dimensionality; its order picks
        which dimension maps to which axis.

    Returns
    -------
    plotly.graph_objects.Figure
        The top-100 tokens per selected dimension (colored by the x-axis
        score) plus the top-500 O tokens as a grey comparison cloud.
    """
    # combined B-/I- score per entity class (arrays over the filtered vocab)
    attention_map = {
        'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
        'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
        'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
    }
    fig = go.Figure()
    # Guard: with nothing selected there is nothing to plot (previously this
    # crashed indexing selected_dimensions[0]).
    if not selected_dimensions:
        fig.update_layout(title="GUS-Net Entity Attentions Visualization")
        return fig

    # ordered per-axis score arrays for up to three dimensions
    axis_data = [attention_map[dim] for dim in selected_dimensions[:3]]

    # Union of each axis' top-100 indices. Computed inline (rather than via
    # select_top_100) so the same indices can align the hover-text scores
    # with the plotted tokens.
    combined_indices = np.unique(
        np.concatenate([np.argsort(data)[-100:] for data in axis_data])
    )
    plotted = [data[combined_indices] for data in axis_data]
    x_data = plotted[0]
    y_data = plotted[1] if len(plotted) > 1 else None
    z_data = plotted[2] if len(plotted) > 2 else None
    tokens_plotted = [filtered_tokens[i] for i in combined_indices]

    # O-token comparison cloud projected onto the same axes; unused axes are
    # pinned to zero.
    o_x = axis_data[0][top_500_o_indices]
    o_y = axis_data[1][top_500_o_indices] if len(axis_data) > 1 else np.zeros_like(o_x)
    o_z = axis_data[2][top_500_o_indices] if len(axis_data) > 2 else np.zeros_like(o_x)

    # Hover text for classified tokens: the per-label arrays are indexed with
    # the SAME combined indices so each tooltip matches its token. (The
    # previous version passed the full filtered arrays, so tooltips showed
    # scores belonging to unrelated tokens.)
    classified_hover_text = create_hover_text(
        tokens_plotted,
        b_gen_attention_filtered[combined_indices], i_gen_attention_filtered[combined_indices],
        b_unfair_attention_filtered[combined_indices], i_unfair_attention_filtered[combined_indices],
        b_stereo_attention_filtered[combined_indices], i_stereo_attention_filtered[combined_indices],
        o_attention_filtered[combined_indices],
    )
    # hover text for the O cloud (already index-aligned with top_500_o_tokens)
    o_hover_text = create_hover_text(
        top_500_o_tokens,
        b_gen_attention_filtered[top_500_o_indices], i_gen_attention_filtered[top_500_o_indices],
        b_unfair_attention_filtered[top_500_o_indices], i_unfair_attention_filtered[top_500_o_indices],
        b_stereo_attention_filtered[top_500_o_indices], i_stereo_attention_filtered[top_500_o_indices],
        o_attention_filtered_top_500,
    )

    classified_marker = dict(
        size=6,
        color=x_data,  # color based on the x-axis data
        colorscale='Viridis',
        opacity=0.85,
    )
    o_marker = dict(size=6, color='grey', opacity=0.5)

    if z_data is not None:
        # 3D scatter plot
        fig.add_trace(go.Scatter3d(
            x=x_data, y=y_data, z=z_data,
            mode='markers',
            marker=classified_marker,
            text=classified_hover_text,
            hoverinfo='text',
            name='Classified Tokens',
        ))
        fig.add_trace(go.Scatter3d(
            x=o_x, y=o_y, z=o_z,
            mode='markers',
            marker=o_marker,
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens',
        ))
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            scene=dict(
                xaxis=dict(title=f"{selected_dimensions[0]} Attention"),
                yaxis=dict(title=f"{selected_dimensions[1]} Attention"),
                zaxis=dict(title=f"{selected_dimensions[2]} Attention"),
            ),
            margin=dict(l=0, r=0, b=0, t=40),
        )
    elif y_data is not None:
        # 2D scatter plot
        fig.add_trace(go.Scatter(
            x=x_data, y=y_data,
            mode='markers',
            marker=classified_marker,
            text=classified_hover_text,
            hoverinfo='text',
            name='Classified Tokens',
        ))
        fig.add_trace(go.Scatter(
            x=o_x, y=o_y,
            mode='markers',
            marker=o_marker,
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens',
        ))
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            xaxis_title=f"{selected_dimensions[0]} Attention",
            yaxis_title=f"{selected_dimensions[1]} Attention",
            margin=dict(l=0, r=0, b=0, t=40),
        )
    else:
        # 1D scatter plot on a flat line (y pinned to 0)
        fig.add_trace(go.Scatter(
            x=x_data, y=np.zeros_like(x_data),
            mode='markers',
            marker=classified_marker,
            text=classified_hover_text,
            hoverinfo='text',
            # renamed from 'GUS Tokens' for consistency with the 2D/3D traces
            name='Classified Tokens',
        ))
        fig.add_trace(go.Scatter(
            x=o_x, y=np.zeros_like(o_x),
            mode='markers',
            marker=o_marker,
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens',
        ))
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            xaxis_title=f"{selected_dimensions[0]} Attention",
            margin=dict(l=0, r=0, b=0, t=40),
        )
    return fig
def get_top_tokens_for_entities(selected_dimensions):
    """Map each selected dimension to its 10 highest-attention tokens.

    Returns a dict of ``{dimension: [(token, score), ...]}`` with entries in
    ascending score order; unrecognized dimension names are skipped.
    """
    entity_map = {
        'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
        'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
        'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
    }
    top_tokens_info = {}
    for dimension in selected_dimensions:
        scores = entity_map.get(dimension)
        if scores is None:
            continue  # ignore names outside the known entity classes
        top_idx = np.argsort(scores)[-10:]  # top 10 tokens, ascending
        top_tokens_info[dimension] = [
            (filtered_tokens[i], scores[i]) for i in top_idx
        ]
    return top_tokens_info
def update_gradio(selected_dimensions):
    """Gradio callback: rebuild the figure and the top-token summary text."""
    fig = create_plot(selected_dimensions)
    top_tokens_info = get_top_tokens_for_entities(selected_dimensions)
    # assemble the textbox content one line at a time, then join once
    parts = []
    for entity, tokens_info in top_tokens_info.items():
        parts.append(f"\nTop tokens for {entity}:\n")
        parts.extend(
            f"Token: {token}, Attention Score: {score:.3f}\n"
            for token, score in tokens_info
        )
    return fig, "".join(parts)
def render_gradio_interface():
    """Assemble the Gradio Blocks UI: dimension checkboxes, plot, summary box."""
    default_dims = ["Generalization", "Unfairness", "Stereotype"]
    with gr.Blocks() as interface:
        with gr.Column():
            dimensions_input = gr.CheckboxGroup(
                choices=list(default_dims),
                label="Select Dimensions to Plot",
                value=list(default_dims),  # start in the full 3D view
            )
            plot_output = gr.Plot(label="Token Attention Visualization")
            top_tokens_output = gr.Textbox(
                label="Top Tokens for Each Entity Class", lines=10
            )
            # re-render whenever the checkbox selection changes
            dimensions_input.change(
                fn=update_gradio,
                inputs=[dimensions_input],
                outputs=[plot_output, top_tokens_output],
            )
            # populate the plot and summary on first page load
            interface.load(
                fn=lambda: update_gradio(list(default_dims)),
                inputs=None,
                outputs=[plot_output, top_tokens_output],
            )
    return interface
# Guard the entry point so importing this module (e.g. for testing or reuse)
# does not launch a server as a side effect; running it as a script behaves
# exactly as before.
if __name__ == "__main__":
    interface = render_gradio_interface()
    interface.launch()