Spaces:
Sleeping
Sleeping
Add plotting functionality
Browse files
app.py
CHANGED
@@ -1,7 +1,41 @@
|
|
1 |
import torch
|
2 |
import streamlit as st
|
|
|
|
|
|
|
3 |
from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, GenerationConfig, AutoModelForCausalLM
|
4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
st.title('How do LLM choose their words?')
|
6 |
|
7 |
col1, col2 = st.columns(2)
|
@@ -65,4 +99,5 @@ if instruction:
|
|
65 |
|
66 |
st.write(output_text)
|
67 |
|
68 |
-
|
|
|
|
1 |
import torch
|
2 |
import streamlit as st
|
3 |
+
import numpy as np
|
4 |
+
import plotly.express as px, plotly.graph_objects as go
|
5 |
+
from plotly.subplots import make_subplots
|
6 |
from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, GenerationConfig, AutoModelForCausalLM
|
7 |
|
8 |
+
def top_token_ids(outputs, threshold=-np.inf):
    """Return, for each output step, the ids of tokens scoring above a threshold.

    Args:
        outputs: Mapping with a ``'scores'`` entry holding one score tensor
            per generation step. Only row 0 of each tensor is inspected
            (assumes each tensor is laid out as (batch, vocab) -- TODO confirm
            against the ``generate`` call that produces ``outputs``).
        threshold: Tokens are kept only when their score is strictly greater
            than this value. With the default ``-inf`` every token with a
            finite score is kept.

    Returns:
        A list with one 1-D NumPy integer array per step. Each array holds
        the selected token ids ordered by ascending score; it is empty when
        no token exceeds the threshold.
    """
    indexes = []
    for tensor in outputs['scores']:
        # Work on the first row only, so the ids found here are valid
        # indexes into that same row. Flattening the whole tensor would
        # yield out-of-range ids as soon as it holds more than one row.
        row = tensor[0].detach().cpu().numpy()
        candidates = np.flatnonzero(row > threshold)
        # Ascending score order: the best candidate ends up last.
        ordering = np.argsort(row[candidates])
        indexes.append(candidates[ordering])
    return indexes
|
20 |
+
|
21 |
+
def plot_word_scores(top_token_ids, outputs, tokenizer, boolq=False, width=600):
    """Build a stacked figure of horizontal bar charts, one per output step.

    Args:
        top_token_ids: Sequence with, for each step, the token ids to plot.
            NOTE: this parameter shadows the module-level ``top_token_ids``
            function inside this body; the name is kept for compatibility.
        outputs: Mapping with a ``'scores'`` entry; ``outputs['scores'][step][0]``
            is indexed with that step's token ids to obtain the bar lengths.
        tokenizer: Object exposing ``convert_ids_to_tokens``, used to label
            the bars.
        boolq: Currently unused; kept so existing callers keep working.
        width: Figure width passed to the plotly layout.

    Returns:
        A plotly figure with one subplot row per step (300 px per row), or
        an empty figure when there are no steps to plot.
    """
    n_steps = len(top_token_ids)
    if n_steps == 0:
        # make_subplots needs at least one row; return a blank figure
        # rather than failing when nothing was generated.
        return go.Figure(layout={'width': width})

    fig = make_subplots(rows=n_steps, cols=1)
    for step, candidates in enumerate(top_token_ids):
        bars = go.Bar(
            y=tokenizer.convert_ids_to_tokens(candidates),
            x=outputs['scores'][step][0][candidates].cpu(),
            orientation='h',
        )
        # add_trace with row/col replaces the deprecated append_trace.
        fig.add_trace(bars, row=step + 1, col=1)

    fig.update_layout(
        width=width,  # honour the caller's width instead of a fixed 500
        height=300 * n_steps,
        showlegend=False,
    )
    return fig
|
38 |
+
|
39 |
st.title('How do LLM choose their words?')
|
40 |
|
41 |
col1, col2 = st.columns(2)
|
|
|
99 |
|
100 |
st.write(output_text)
|
101 |
|
102 |
+
fig = plot_word_scores(top_token_ids(outputs, threshold=-1), outputs, tokenizer)
|
103 |
+
st.plotly_chart(fig, use_container_width=False)
|