File size: 1,939 Bytes
16fc6d2
 
 
 
 
 
39de0ac
16fc6d2
 
 
 
0f406af
 
 
 
 
16fc6d2
 
 
7bdd4ad
 
 
 
16fc6d2
 
 
 
 
 
 
 
abdf1f2
 
16fc6d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import plotly.express as px


# Hugging Face Hub checkpoint whose attention maps are visualized.
model_name = 'Qwen/Qwen2-1.5B'
# Prefer the GPU when PyTorch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

@st.cache_resource
def load_model():
    """Load the causal LM named by ``model_name`` and move it to ``device``.

    ``st.cache_resource`` keeps the single loaded instance across Streamlit
    reruns. The Hugging Face access token is read from
    ``st.secrets['hf_token']``.

    The "eager" attention implementation is requested explicitly because this
    app calls the model with ``output_attentions=True``. Fused kernels
    (SDPA / FlashAttention) do not materialise per-head attention
    probabilities.

    Returns:
        The model in bfloat16, placed on ``device``.
    """
    return AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        attn_implementation='eager',
        token=st.secrets['hf_token'],
    ).to(device)

@st.cache_resource
def load_tokenizer():
    """Fetch the tokenizer that matches ``model_name``.

    Cached with ``st.cache_resource`` so it is built once rather than on
    every Streamlit rerun. Authenticates to the Hugging Face Hub with the
    ``hf_token`` entry of ``st.secrets``.
    """
    hub_token = st.secrets['hf_token']
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name, token=hub_token)
    return loaded_tokenizer

@st.cache_data()
def get_attention_weights_and_tokens(text):
    """Run the module-level ``model`` on *text* and collect its attention maps.

    Results are memoised by ``st.cache_data`` keyed on *text*. The cache
    pickles what is returned, so the tensors are moved to the CPU first.
    This avoids keeping a GPU copy of every cached result alive.

    Args:
        text: Raw input string to tokenize and feed to the model.

    Returns:
        ``(attentions, tokens)``:

        - ``attentions`` is a list holding one float32 CPU tensor per layer,
          taken from ``output.attentions``. Assumed shape is
          ``(batch, num_heads, seq_len, seq_len)`` per the transformers docs.
        - ``tokens`` holds the decoded string of each input id, in order.
    """
    tokenized = tokenizer(text, return_tensors='pt')
    # Decode ids one at a time so every heatmap row/column gets its own label.
    tokens = [tokenizer.decode(token_id) for token_id in tokenized.input_ids[0]]
    tokenized = tokenized.to(device)
    # Inference only: no autograd graph is needed for visualisation.
    with torch.no_grad():
        output = model(**tokenized, output_attentions=True)
    # Upcast from bfloat16 (which NumPy/Plotly cannot consume) and move off
    # the accelerator so the cached copies do not pin device memory.
    attentions = [
        attention.to(device='cpu', dtype=torch.float32)
        for attention in output.attentions
    ]
    return attentions, tokens

model = load_model()
tokenizer = load_tokenizer()

st.title('Attention visualizer')
text = st.text_area('Write your text here and see attention weights.')
# The slider is 1-based for the user; subtract 1 to get the index into the
# per-layer list of attentions.
layer = st.slider(
    'Which layer do you want to see?',
    min_value=1,
    max_value=model.config.num_hidden_layers
) - 1

# 'Average' means the mean over all heads of the chosen layer; an integer
# selects a single head (1-based in the UI).
head = st.select_slider(
    'Which head do you want to see?',
    options=['Average'] + list(range(1, model.config.num_attention_heads + 1))
)
if text:
    attentions, tokens = get_attention_weights_and_tokens(text)
    # First (only) batch element: one square matrix per head. Rows are
    # assumed to be query positions and columns key positions, per the
    # transformers attention layout.
    layer_heads = attentions[layer].cpu()[0]
    if head == 'Average':
        weights = layer_heads.mean(dim=0)
    else:
        weights = layer_heads[head - 1]
    fig = px.imshow(
        weights.numpy(),
        # Name the axes and colour scale so the heatmap and its hover text
        # say which side is the attending (query) token.
        labels={'x': 'Key token', 'y': 'Query token', 'color': 'Attention weight'},
    )
    # Label every row/column with its decoded token instead of its index.
    token_axis = {
        'ticktext': tokens,
        'tickvals': list(range(len(tokens))),
    }
    fig.update_layout(
        xaxis=token_axis,
        yaxis=token_axis,
        height=800,
    )

    st.plotly_chart(fig)