patpizio committed on
Commit 5cf0970 · 1 Parent(s): 31df035

Add plotting functionality

Files changed (1)
  1. app.py +36 -1
app.py CHANGED
@@ -1,7 +1,41 @@
 import torch
 import streamlit as st
+import numpy as np
+import plotly.express as px, plotly.graph_objects as go
+from plotly.subplots import make_subplots
 from transformers import AutoTokenizer, T5Tokenizer, T5ForConditionalGeneration, GenerationConfig, AutoModelForCausalLM
 
+def top_token_ids(outputs, threshold=-np.inf):
+    "Returns the index of the tokens whose score exceeds a threshold, for each output step"
+    indexes = []
+    for tensor in outputs['scores']:
+        candidates = np.argwhere(tensor.flatten().cpu() > threshold).numpy()[0]
+        ordering_mask = np.argsort(tensor[0][candidates].cpu())
+        candidates = candidates[ordering_mask]
+        if not isinstance(candidates, np.ndarray):
+            indexes.append(np.array([candidates]))
+        else:
+            indexes.append(candidates)
+    return indexes
+
+def plot_word_scores(top_token_ids, outputs, tokenizer, boolq=False, width=600):
+    fig = make_subplots(rows=len(top_token_ids), cols=1)
+    for step, candidates in enumerate(top_token_ids):
+        fig.append_trace(
+            go.Bar(
+                y=tokenizer.convert_ids_to_tokens(candidates),
+                x=outputs['scores'][step][0][candidates].cpu(),
+                orientation='h'
+            ),
+            row=step+1, col=1
+        )
+    fig.update_layout(
+        width=500,
+        height=300*len(top_token_ids),
+        showlegend=False
+    )
+    return fig
+
 st.title('How do LLM choose their words?')
 
 col1, col2 = st.columns(2)
@@ -65,4 +99,5 @@ if instruction:
 
     st.write(output_text)
 
-    model.compute_transition_scores(outputs.sequences, outputs.scores, normalize_logits=False)
+    fig = plot_word_scores(top_token_ids(outputs, threshold=-1), outputs, tokenizer)
+    st.plotly_chart(fig, use_container_width=False)
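
Note: the new helpers read outputs['scores'], which model.generate only produces when it is asked to return a dict with per-step scores. A minimal sketch of such a generation call, using the T5 classes already imported in app.py; the model name, prompt, and max_new_tokens below are placeholders for illustration, not taken from this commit:

from transformers import T5Tokenizer, T5ForConditionalGeneration

# Hypothetical model/prompt; the app's actual choices are configured elsewhere in app.py.
tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")
model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small")

inputs = tokenizer("Is the sky blue?", return_tensors="pt")
outputs = model.generate(
    **inputs,
    max_new_tokens=10,
    return_dict_in_generate=True,  # makes the result dict-like, so outputs['scores'] works
    output_scores=True,            # one (batch, vocab_size) score tensor per generated step
)

With greedy decoding and a single prompt, each element of outputs['scores'] has shape (1, vocab_size), which matches the tensor[0][candidates] indexing in top_token_ids and plot_word_scores.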