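"""Streamlit app that visualizes the attention weights of Qwen/Qwen2-1.5B.

Type some text, pick a layer and a head (or the average over all heads),
and the app renders the corresponding attention matrix as a Plotly heatmap.

Run with `streamlit run app.py` (assuming this file is saved as app.py,
the conventional Space entry point).
"""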
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import plotly.express as px

model_name = 'Qwen/Qwen2-1.5B'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
@st.cache_resource
def load_model():
    # Cache across Streamlit reruns so the model is downloaded and loaded only once.
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        token=st.secrets['hf_token']
    ).to(device)
@st.cache_resource
def load_tokenizer():
    return AutoTokenizer.from_pretrained(
        model_name,
        token=st.secrets['hf_token']
    )
def get_attention_weights_and_tokens(text):
    tokenized = tokenizer(text, return_tensors='pt')
    # Decode each input id separately so the heatmap axes can be labelled per token.
    tokens = [tokenizer.decode(token) for token in tokenized.input_ids[0]]
    tokenized = tokenized.to(device)
    with torch.no_grad():  # inference only, no gradients needed
        output = model(**tokenized, output_attentions=True)
    # output.attentions holds one tensor per layer, each of shape
    # (batch, num_heads, seq_len, seq_len); cast to float32 because
    # NumPy/Plotly cannot handle bfloat16.
    attentions = [attention.to(torch.float32) for attention in output.attentions]
    return attentions, tokens
model = load_model()
tokenizer = load_tokenizer()

st.title('Attention visualizer')
text = st.text_area('Write your text here and see attention weights.')
# Sliders are 1-indexed for readability; layer is shifted back to a 0-based index.
layer = st.slider(
    'Which layer do you want to see?',
    min_value=1,
    max_value=model.config.num_hidden_layers
) - 1
head = st.select_slider(
    'Which head do you want to see?',
    options=['Average'] + list(range(1, model.config.num_attention_heads + 1))
)
if text:
    attentions, tokens = get_attention_weights_and_tokens(text)
    if head == 'Average':
        # Mean over the head dimension of the (num_heads, seq_len, seq_len) tensor.
        weights = attentions[layer].cpu()[0].mean(dim=0)
    else:
        weights = attentions[layer].cpu()[0][head - 1]
    fig = px.imshow(weights)
    # Replace the numeric tick labels with the decoded tokens on both axes.
    fig.update_layout(
        xaxis={
            'ticktext': tokens,
            'tickvals': list(range(len(tokens))),
        },
        yaxis={
            'ticktext': tokens,
            'tickvals': list(range(len(tokens))),
        },
        height=800,
    )
    st.plotly_chart(fig)
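
# Note: st.secrets['hf_token'] expects a Hugging Face access token, supplied
# via .streamlit/secrets.toml (or the Space's secrets settings), e.g.:
#   hf_token = "hf_..."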