Spaces:
Sleeping
Sleeping
File size: 1,939 Bytes
16fc6d2 39de0ac 16fc6d2 0f406af 16fc6d2 7bdd4ad 16fc6d2 abdf1f2 16fc6d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import plotly.express as px
# Hugging Face Hub identifier of the checkpoint whose attention is visualised.
model_name = 'Qwen/Qwen2-1.5B'

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
@st.cache_resource
def load_model():
    """Load the causal language model and place it on ``device``.

    The checkpoint named by ``model_name`` is loaded in bfloat16 using the
    Hugging Face token stored in Streamlit secrets under ``hf_token``.
    Because of ``st.cache_resource`` the load happens once and later reruns
    of the script reuse the same object.

    Returns:
        The model returned by ``AutoModelForCausalLM.from_pretrained``,
        moved to ``device``.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        token=st.secrets['hf_token'],
        # This app exists to display attention maps. Fused attention
        # backends (SDPA / flash attention) do not materialise the weights,
        # so ask for the plain implementation explicitly instead of relying
        # on the library default.
        attn_implementation='eager',
    ).to(device)
@st.cache_resource
def load_tokenizer():
    """Load the tokenizer that matches ``model_name``.

    Authenticates with the ``hf_token`` entry of Streamlit secrets. The
    result is cached by ``st.cache_resource`` so it is created only once.

    Returns:
        The tokenizer returned by ``AutoTokenizer.from_pretrained``.
    """
    hub_token = st.secrets['hf_token']
    loaded = AutoTokenizer.from_pretrained(model_name, token=hub_token)
    return loaded
@torch.no_grad()
# The cache key is free-form user text, so cap the number of stored results;
# each entry holds the attention tensors of every layer.
@st.cache_data(max_entries=32)
def get_attention_weights_and_tokens(text):
    """Run the model on ``text`` and return its attention maps and tokens.

    Uses the module-level ``model`` and ``tokenizer``. Gradients are
    disabled for the forward pass and results are cached per input text.

    Args:
        text: The string to tokenize and feed to the model.

    Returns:
        A ``(attentions, tokens)`` tuple. ``attentions`` is a list with one
        float32 CPU tensor per element of the model's ``attentions`` output;
        ``tokens`` is the list of input token ids decoded one at a time, in
        input order.
    """
    encoded = tokenizer(text, return_tensors='pt')
    # Decode every id separately so each position gets its own label.
    tokens = [tokenizer.decode(token_id) for token_id in encoded.input_ids[0]]
    encoded = encoded.to(device)
    output = model(**encoded, output_attentions=True)
    # Move to CPU before caching so stored results do not pin accelerator
    # memory, then widen to float32 (presumably because the plotting side
    # cannot consume bfloat16 -- confirm if the dtype is ever changed).
    attentions = [
        attention.cpu().to(torch.float32)
        for attention in output.attentions
    ]
    return attentions, tokens
# Heavy objects come from the cached loaders defined above.
model = load_model()
tokenizer = load_tokenizer()

st.title('Attention visualizer')
text = st.text_area('Write your text here and see attention weights.')

# The widgets are 1-based for the user; the layer is converted to a
# 0-based index right away, the head is converted when it is used.
layer_count = model.config.num_hidden_layers
layer = st.slider(
    'Which layer do you want to see?',
    min_value=1,
    max_value=layer_count,
) - 1

head_options = ['Average']
head_options.extend(range(1, model.config.num_attention_heads + 1))
head = st.select_slider(
    'Which head do you want to see?',
    options=head_options,
)

if text:
    attentions, tokens = get_attention_weights_and_tokens(text)

    # Drop the leading (batch) axis of the selected layer's attention.
    layer_attention = attentions[layer].cpu()[0]
    if head != 'Average':
        weights = layer_attention[head - 1]
    else:
        weights = layer_attention.mean(dim=0)

    fig = px.imshow(weights)

    # Label both axes with the decoded tokens, one tick per position.
    tick_positions = list(range(len(tokens)))
    fig.update_layout(
        xaxis={'ticktext': tokens, 'tickvals': tick_positions},
        yaxis={'ticktext': tokens, 'tickvals': list(tick_positions)},
        height=800,
    )
    st.plotly_chart(fig)