import json

import gradio as gr
import nltk
from nltk.corpus import words
import numpy as np
import plotly.graph_objects as go
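# expected data files (inferred from how they're used below):
#   vocab_embeddings.npy        - per-token embedding matrix (loaded but not used elsewhere in this app)
#   vocab_attention_scores.json - list of dicts with one attention score per entity tag
#                                 (B-GEN, I-GEN, B-UNFAIR, I-UNFAIR, B-STEREO, I-STEREO, O)
#   vocab_tokens.json           - list of token strings aligned with the entries above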
# load files with embeddings, attention scores, and tokens
vocab_embeddings = np.load('vocab_embeddings.npy')
with open('vocab_attention_scores.json', 'r') as f:
    vocab_attention_scores = json.load(f)
with open('vocab_tokens.json', 'r') as f:
    vocab_tokens = json.load(f)
# convert attention scores to numpy arrays, one per entity tag
b_gen_attention = np.array([score['B-GEN'] for score in vocab_attention_scores])
i_gen_attention = np.array([score['I-GEN'] for score in vocab_attention_scores])
b_unfair_attention = np.array([score['B-UNFAIR'] for score in vocab_attention_scores])
i_unfair_attention = np.array([score['I-UNFAIR'] for score in vocab_attention_scores])
b_stereo_attention = np.array([score['B-STEREO'] for score in vocab_attention_scores])
i_stereo_attention = np.array([score['I-STEREO'] for score in vocab_attention_scores])
o_attention = np.array([score['O'] for score in vocab_attention_scores]) # Use actual O scores
# keep only tokens that are dictionary english words, plus wordpiece subwords (prefixed with ##)
nltk.download('words')
english_words = set(words.words())
filtered_indices = [i for i, token in enumerate(vocab_tokens) if token in english_words or token.startswith("##")]
filtered_tokens = [vocab_tokens[i] for i in filtered_indices]
b_gen_attention_filtered = b_gen_attention[filtered_indices]
i_gen_attention_filtered = i_gen_attention[filtered_indices]
b_unfair_attention_filtered = b_unfair_attention[filtered_indices]
i_unfair_attention_filtered = i_unfair_attention[filtered_indices]
b_stereo_attention_filtered = b_stereo_attention[filtered_indices]
i_stereo_attention_filtered = i_stereo_attention[filtered_indices]
o_attention_filtered = o_attention[filtered_indices]
# select the top 500 tokens by O attention to plot as a background/reference set
top_500_o_indices = np.argsort(o_attention_filtered)[-500:]
top_500_o_tokens = [filtered_tokens[i] for i in top_500_o_indices]
o_attention_filtered_top_500 = o_attention_filtered[top_500_o_indices]
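# note: top_500_o_indices index into the filtered arrays above, so they can also be used to
# slice the per-entity attention arrays for the same 500 tokens (done in create_plot below)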
# hover tooltip text for tokens
def create_hover_text(tokens, b_gen, i_gen, b_unfair, i_unfair, b_stereo, i_stereo, o_val):
    hover_text = []
    for i in range(len(tokens)):
        hover_text.append(
            f"Token: {tokens[i]}<br>"
            f"B-GEN: {b_gen[i]:.3f}, I-GEN: {i_gen[i]:.3f}<br>"
            f"B-UNFAIR: {b_unfair[i]:.3f}, I-UNFAIR: {i_unfair[i]:.3f}<br>"
            f"B-STEREO: {b_stereo[i]:.3f}, I-STEREO: {i_stereo[i]:.3f}<br>"
            f"O: {o_val[i]:.3f}"
        )
    return hover_text
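# illustrative hover string produced above (token and values are made up):
# "Token: nurse<br>B-GEN: 0.120, I-GEN: 0.030<br>B-UNFAIR: 0.010, I-UNFAIR: 0.005<br>B-STEREO: 0.200, I-STEREO: 0.080<br>O: 0.450"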
# select the top 100 tokens for each selected entity dimension (union across dimensions)
def select_top_100(*data_arrays):
    indices_list = []
    for data in data_arrays:
        if data is not None:
            top_indices = np.argsort(data)[-100:]
            indices_list.append(top_indices)
    combined_indices = np.unique(np.concatenate(indices_list))
    # filter each dimension's data down to the combined indices
    filtered_data = [data[combined_indices] if data is not None else None for data in data_arrays]
    tokens_filtered = [filtered_tokens[i] for i in combined_indices]
    # also return the combined indices so callers can slice other aligned arrays (e.g. for hover text)
    return (*filtered_data, tokens_filtered, combined_indices)
# build a 1D, 2D, or 3D plot depending on how many dimensions are selected
def create_plot(selected_dimensions):
    # nothing selected: return an empty figure instead of crashing
    if not selected_dimensions:
        return go.Figure()
    # combined (B- + I-) attention per entity
    attention_map = {
        'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
        'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
        'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
    }
    # init x, y, z so they can be assigned in any order
    x_data, y_data, z_data = None, None, None
    # map the selected dimensions onto axes in the order they were chosen
    if len(selected_dimensions) > 0:
        x_data = attention_map[selected_dimensions[0]]
    if len(selected_dimensions) > 1:
        y_data = attention_map[selected_dimensions[1]]
    if len(selected_dimensions) > 2:
        z_data = attention_map[selected_dimensions[2]]
    # select top 100 data points for each selected dimension
    x_data, y_data, z_data, tokens_filtered, combined_indices = select_top_100(x_data, y_data, z_data)
    # project the top 500 O tokens onto the same selected dimensions
    o_x = attention_map[selected_dimensions[0]][top_500_o_indices]
    if len(selected_dimensions) > 1:
        o_y = attention_map[selected_dimensions[1]][top_500_o_indices]
    else:
        o_y = np.zeros_like(o_x)
    if len(selected_dimensions) > 2:
        o_z = attention_map[selected_dimensions[2]][top_500_o_indices]
    else:
        o_z = np.zeros_like(o_x)
    # hover text for GUS tokens, sliced to the same rows that are plotted
    classified_hover_text = create_hover_text(
        tokens_filtered,
        b_gen_attention_filtered[combined_indices], i_gen_attention_filtered[combined_indices],
        b_unfair_attention_filtered[combined_indices], i_unfair_attention_filtered[combined_indices],
        b_stereo_attention_filtered[combined_indices], i_stereo_attention_filtered[combined_indices],
        o_attention_filtered[combined_indices]
    )
    # hover text for O tokens
    o_hover_text = create_hover_text(
        top_500_o_tokens,
        b_gen_attention_filtered[top_500_o_indices], i_gen_attention_filtered[top_500_o_indices],
        b_unfair_attention_filtered[top_500_o_indices], i_unfair_attention_filtered[top_500_o_indices],
        b_stereo_attention_filtered[top_500_o_indices], i_stereo_attention_filtered[top_500_o_indices],
        o_attention_filtered_top_500
    )
    # plot
    fig = go.Figure()
    if x_data is not None and y_data is not None and z_data is not None:
        # 3d scatter plot
        fig.add_trace(go.Scatter3d(
            x=x_data,
            y=y_data,
            z=z_data,
            mode='markers',
            marker=dict(
                size=6,
                color=x_data,  # color based on the x-axis data
                colorscale='Viridis',
                opacity=0.85,
            ),
            text=classified_hover_text,
            hoverinfo='text',
            name='Classified Tokens'
        ))
        # add top 500 O tokens to the plot too
        fig.add_trace(go.Scatter3d(
            x=o_x,
            y=o_y,
            z=o_z,
            mode='markers',
            marker=dict(
                size=6,
                color='grey',
                opacity=0.5,
            ),
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens'
        ))
    elif x_data is not None and y_data is not None:
        # 2d scatter plot
        fig.add_trace(go.Scatter(
            x=x_data,
            y=y_data,
            mode='markers',
            marker=dict(
                size=6,
                color=x_data,  # color based on the x-axis data
                colorscale='Viridis',
                opacity=0.85,
            ),
            text=classified_hover_text,
            hoverinfo='text',
            name='Classified Tokens'
        ))
        # add top 500 O tokens to the plot too
        fig.add_trace(go.Scatter(
            x=o_x,
            y=o_y,
            mode='markers',
            marker=dict(
                size=6,
                color='grey',
                opacity=0.5,
            ),
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens'
        ))
    elif x_data is not None:
        # 1D scatter plot
        fig.add_trace(go.Scatter(
            x=x_data,
            y=np.zeros_like(x_data),
            mode='markers',
            marker=dict(
                size=6,
                color=x_data,
                colorscale='Viridis',
                opacity=0.85,
            ),
            text=classified_hover_text,
            hoverinfo='text',
            name='GUS Tokens'
        ))
        fig.add_trace(go.Scatter(
            x=o_x,
            y=np.zeros_like(o_x),
            mode='markers',
            marker=dict(
                size=6,
                color='grey',
                opacity=0.5,
            ),
            text=o_hover_text,
            hoverinfo='text',
            name='O Tokens'
        ))
    # update layout dynamically
    if x_data is not None and y_data is not None and z_data is not None:
        # 3D
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            scene=dict(
                xaxis=dict(title=f"{selected_dimensions[0]} Attention"),
                yaxis=dict(title=f"{selected_dimensions[1]} Attention"),
                zaxis=dict(title=f"{selected_dimensions[2]} Attention"),
            ),
            margin=dict(l=0, r=0, b=0, t=40),
        )
    elif x_data is not None and y_data is not None:
        # 2D
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            xaxis_title=f"{selected_dimensions[0]} Attention",
            yaxis_title=f"{selected_dimensions[1]} Attention",
            margin=dict(l=0, r=0, b=0, t=40),
        )
    elif x_data is not None:
        # 1D
        fig.update_layout(
            title="GUS-Net Entity Attentions Visualization",
            xaxis_title=f"{selected_dimensions[0]} Attention",
            margin=dict(l=0, r=0, b=0, t=40),
        )
    return fig
# top 10 tokens (by combined B- + I- attention) for each selected entity
def get_top_tokens_for_entities(selected_dimensions):
    entity_map = {
        'Generalization': b_gen_attention_filtered + i_gen_attention_filtered,
        'Unfairness': b_unfair_attention_filtered + i_unfair_attention_filtered,
        'Stereotype': b_stereo_attention_filtered + i_stereo_attention_filtered,
    }
    top_tokens_info = {}
    for dimension in selected_dimensions:
        if dimension in entity_map:
            attention_scores = entity_map[dimension]
            top_indices = np.argsort(attention_scores)[-10:]  # top 10 tokens
            top_tokens = [filtered_tokens[i] for i in top_indices]
            top_scores = attention_scores[top_indices]
            top_tokens_info[dimension] = list(zip(top_tokens, top_scores))
    return top_tokens_info
def update_gradio(selected_dimensions):
    fig = create_plot(selected_dimensions)
    top_tokens_info = get_top_tokens_for_entities(selected_dimensions)
    formatted_top_tokens = ""
    for entity, tokens_info in top_tokens_info.items():
        formatted_top_tokens += f"\nTop tokens for {entity}:\n"
        for token, score in tokens_info:
            formatted_top_tokens += f"Token: {token}, Attention Score: {score:.3f}\n"
    return fig, formatted_top_tokens
def render_gradio_interface():
    with gr.Blocks() as interface:
        with gr.Column():
            dimensions_input = gr.CheckboxGroup(
                choices=["Generalization", "Unfairness", "Stereotype"],
                label="Select Dimensions to Plot",
                value=["Generalization", "Unfairness", "Stereotype"]  # defaults to 3D
            )
            plot_output = gr.Plot(label="Token Attention Visualization")
            top_tokens_output = gr.Textbox(label="Top Tokens for Each Entity Class", lines=10)
            dimensions_input.change(
                fn=update_gradio,
                inputs=[dimensions_input],
                outputs=[plot_output, top_tokens_output]
            )
            interface.load(
                fn=lambda: update_gradio(["Generalization", "Unfairness", "Stereotype"]),
                inputs=None,
                outputs=[plot_output, top_tokens_output]
            )
    return interface
interface = render_gradio_interface()
interface.launch()
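# to run locally (assumes vocab_embeddings.npy, vocab_attention_scores.json, and vocab_tokens.json
# sit next to this script): install numpy, plotly, gradio, and nltk, then run `python app.py`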