s-a-malik committed
Commit 8aba6d1 · 1 Parent(s): 32936b7

add accuracy probe
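
For context, here is a minimal sketch of the scoring convention the diff below relies on, assuming sklearn-style probes that expose predict_proba (the names se_probe and acc_probe follow app.py; the helper itself is illustrative, not part of the commit, and the sign flip on the accuracy side is this sketch's reading of the intent that confidently-correct tokens render green):

import numpy as np

def signed_uncertainty_scores(se_probe, acc_probe, se_feats: np.ndarray, acc_feats: np.ndarray):
    # Map each probe's class-1 probability onto a shared [-1, 1] scale:
    # positive reads as "uncertain" (red), negative as "certain" (green).
    # Semantic-entropy probe: p(high semantic entropy) in [0, 1] -> [-1, 1].
    se_score = se_probe.predict_proba(se_feats.reshape(1, -1))[0][1] * 2 - 1
    # Accuracy probe: p(answer is correct) in [0, 1]; negated here (an assumption
    # about the intended mapping) so high predicted accuracy lands near -1 (green).
    acc_score = -(acc_probe.predict_proba(acc_feats.reshape(1, -1))[0][1] * 2 - 1)
    return se_score, acc_score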

Files changed (1):
  1. app.py +214 -66
app.py CHANGED
@@ -3,6 +3,7 @@ import pickle as pkl
 from pathlib import Path
 from threading import Thread
 from typing import List, Tuple, Iterator
+from queue import Queue
 
 import spaces
 import gradio as gr
@@ -14,11 +15,22 @@ MAX_MAX_NEW_TOKENS = 2048
 DEFAULT_MAX_NEW_TOKENS = 1024
 MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
 
-DESCRIPTION = """\
-This Space demonstrates the Llama-2-7b-chat model with a semantic uncertainty probe.
-The highlighted text shows the model's uncertainty in real-time, with green indicating more certain generations and red indicating higher uncertainty.
+DESCRIPTION = """
+<h1>Llama-2 7B Chat with Uncertainty Probes</h1>
+<p>This Space demonstrates the Llama-2-7b-chat model with a semantic uncertainty probe.</p>
+<p>The highlighted text shows the model's uncertainty in real-time:</p>
+<ul>
+<li><span style="background-color: #00FF00; color: black">Green</span> indicates more certain generations</li>
+<li><span style="background-color: #FF0000; color: black">Red</span> indicates more uncertain generations</li>
+</ul>
 """
 
+EXAMPLES = [
+    ["What is the capital of France?", "You are a helpful assistant.", []],
+    ["Explain the theory of relativity in simple terms.", "You are an expert physicist explaining concepts to a layman.", []],
+    ["Write a short poem about artificial intelligence.", "You are a creative poet with a interest in technology.", []]
+]
+
 if torch.cuda.is_available():
     model_id = "meta-llama/Llama-2-7b-chat-hf"
     # TODO load the full model not the 8bit one?
@@ -32,10 +44,91 @@ if torch.cuda.is_available():
         probe_data = pkl.load(f)
         # take the NQ open one
         probe_data = probe_data[-2]
-        probe = probe_data['t_bmodel']
-        layer_range = probe_data['sep_layer_range']
+        se_probe = probe_data['t_bmodel']
+        se_layer_range = probe_data['sep_layer_range']
         acc_probe = probe_data['t_amodel']
         acc_layer_range = probe_data['ap_layer_range']
+else:
+    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
+
+
+class CustomStreamer(TextIteratorStreamer):
+    """
+    Streamer to also store hidden states in a queue.
+    TODO check this works
+    """
+    def __init__(self, tokenizer, skip_prompt: bool = False, skip_special_tokens: bool = False, **decode_kwargs):
+        super().__init__(tokenizer, skip_prompt, skip_special_tokens, **decode_kwargs)
+        self.hidden_states_queue = Queue()
+
+    def put(self, value):
+        if isinstance(value, dict) and 'hidden_states' in value:
+            self.hidden_states_queue.put(value['hidden_states'])
+        super().put(value)
+
+# Streamer claude
+# def generate(
+#     message: str,
+#     system_prompt: str,
+#     chat_history: List[Tuple[str, str]],
+#     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
+#     temperature: float = 0.6,
+#     top_p: float = 0.9,
+#     top_k: int = 50,
+#     repetition_penalty: float = 1.2,
+# ) -> Iterator[Tuple[str, str]]:
+#     conversation = []
+#     if system_prompt:
+#         conversation.append({"role": "system", "content": system_prompt})
+#     for user, assistant in chat_history:
+#         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
+#     conversation.append({"role": "user", "content": message})
+
+#     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
+#     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
+#         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
+#         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
+#     input_ids = input_ids.to(model.device)
+
+#     streamer = CustomStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+#     generation_kwargs = dict(
+#         input_ids=input_ids,
+#         max_new_tokens=max_new_tokens,
+#         do_sample=True,
+#         top_p=top_p,
+#         top_k=top_k,
+#         temperature=temperature,
+#         repetition_penalty=repetition_penalty,
+#         streamer=streamer,
+#         output_hidden_states=True,
+#         return_dict_in_generate=True,
+#     )
+
+#     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+#     thread.start()
+
+#     se_highlighted_text = ""
+#     acc_highlighted_text = ""
+#     for new_text in streamer:
+#         hidden_states = streamer.hidden_states_queue.get()
+
+#         # Semantic Uncertainty Probe
+#         se_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
+#         se_concat_layers = se_token_embeddings.numpy()[se_layer_range[0]:se_layer_range[1]].reshape(-1)
+#         se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+#         # Accuracy Probe
+#         acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
+#         acc_concat_layers = acc_token_embeddings.numpy()[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
+#         acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+#         se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
+#         acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
+
+#         se_highlighted_text += se_new_highlighted_text
+#         acc_highlighted_text += acc_new_highlighted_text
+
+#         yield se_highlighted_text, acc_highlighted_text
 
 @spaces.GPU
 def generate(
@@ -84,20 +177,31 @@ def generate(
     hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)
 
     # TODO do this loop on the fly instead of waiting for the whole generation
-    highlighted_text = ""
+    se_highlighted_text = ""
+    acc_highlighted_text = ""
     for i in range(1, len(hidden)):
-        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]) # (num_layers, hidden_size)
-        concat_layers = token_embeddings.numpy()[layer_range[0]:layer_range[1]].reshape(-1) # (num_layers * hidden_size)
-        # pred in range [-1, 1]
-        probe_pred = probe.predict_proba(concat_layers.reshape(1, -1))[0][1] * 2 - 1 # prob of high SE
-        # decode one token at a time
+
+        # Semantic Uncertainty Probe
+        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)
+        se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
+        se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+        # Accuracy Probe
+        # acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
+        acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
+        acc_probe_pred = -1 * acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
        output_id = outputs.sequences[0, input_ids.shape[1]+i]
         output_word = tokenizer.decode(output_id)
-        print(output_id, output_word, probe_pred)
-        new_highlighted_text = highlight_text(output_word, probe_pred)
-        highlighted_text += f" {new_highlighted_text}"
-
-        yield highlighted_text
+        print(output_id, output_word, se_probe_pred, acc_probe_pred)
+
+        se_new_highlighted_text = highlight_text(output_word, se_probe_pred)
+        acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)
+        se_highlighted_text += f" {se_new_highlighted_text}"
+        acc_highlighted_text += f" {acc_new_highlighted_text}"
+
+        yield se_highlighted_text, acc_highlighted_text
+
 
 def highlight_text(text: str, uncertainty_score: float) -> str:
     if uncertainty_score > 0:
@@ -116,56 +220,100 @@ def highlight_text(text: str, uncertainty_score: float) -> str:
         html_color, text
     )
 
+with gr.Blocks(title="Llama-2 7B Chat with Dual Probes", css="footer {visibility: hidden}") as demo:
+    gr.HTML(DESCRIPTION)
+
+    with gr.Row():
+        with gr.Column():
+            message = gr.Textbox(label="Message")
+            system_prompt = gr.Textbox(label="System prompt", lines=2)
+
+        with gr.Column():
+            max_new_tokens = gr.Slider(label="Max new tokens", minimum=1, maximum=MAX_MAX_NEW_TOKENS, step=1, value=DEFAULT_MAX_NEW_TOKENS)
+            temperature = gr.Slider(label="Temperature", minimum=0.1, maximum=4.0, step=0.1, value=0.6)
+            top_p = gr.Slider(label="Top-p (nucleus sampling)", minimum=0.05, maximum=1.0, step=0.05, value=0.9)
+            top_k = gr.Slider(label="Top-k", minimum=1, maximum=1000, step=1, value=50)
+            repetition_penalty = gr.Slider(label="Repetition penalty", minimum=1.0, maximum=2.0, step=0.05, value=1.2)
+
+    with gr.Row():
+        generate_btn = gr.Button("Generate")
+    # add spacing between probes and titles for each output
+    with gr.Row():
+        with gr.Column():
+            title = gr.HTML("<h2>Semantic Uncertainty Probe</h2>")
+            se_output = gr.HTML(label="Semantic Uncertainty Probe")
+        with gr.Column():
+            title = gr.HTML("<h2>Accuracy Probe</h2>")
+            acc_output = gr.HTML(label="Accuracy Probe")
+
+    chat_history = gr.State([])
+
+    # gr.Examples(
+    #     examples=EXAMPLES,
+    #     inputs=[message, system_prompt, chat_history],
+    #     outputs=[se_output, acc_output],
+    #     fn=generate,
+    # )
+
+    generate_btn.click(
+        generate,
+        inputs=[message, system_prompt, chat_history, max_new_tokens, temperature, top_p, top_k, repetition_penalty],
+        outputs=[se_output, acc_output]
+    )
+
+# chat_interface = gr.ChatInterface(
+#     fn=generate,
+#     additional_inputs=[
+#         gr.Textbox(label="System prompt", lines=6),
+#         gr.Slider(
+#             label="Max new tokens",
+#             minimum=1,
+#             maximum=MAX_MAX_NEW_TOKENS,
+#             step=1,
+#             value=DEFAULT_MAX_NEW_TOKENS,
+#         ),
+#         gr.Slider(
+#             label="Temperature",
+#             minimum=0.1,
+#             maximum=4.0,
+#             step=0.1,
+#             value=0.6,
+#         ),
+#         gr.Slider(
+#             label="Top-p (nucleus sampling)",
+#             minimum=0.05,
+#             maximum=1.0,
+#             step=0.05,
+#             value=0.9,
+#         ),
+#         gr.Slider(
+#             label="Top-k",
+#             minimum=1,
+#             maximum=1000,
+#             step=1,
+#             value=50,
+#         ),
+#         gr.Slider(
+#             label="Repetition penalty",
+#             minimum=1.0,
+#             maximum=2.0,
+#             step=0.05,
+#             value=1.2,
+#         ),
+#     ],
+#     stop_btn=None,
+#     examples=[
+#         ["What is the capital of France?"],
+#         ["Who landed on the moon?"],
+#         ["Who is Yarin Gal?"]
+#     ],
+#     title="Llama-2 7B Chat with Streamable Semantic Uncertainty Probe",
+#     description=DESCRIPTION,
+# )
+
+# if __name__ == "__main__":
+#     chat_interface.launch()
 
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["What is the capital of France?"],
-        ["Explain the theory of relativity in simple terms."],
-        ["Write a short poem about artificial intelligence."]
-    ],
-    title="Llama-2 7B Chat with Streamable Semantic Uncertainty Probe",
-    description=DESCRIPTION,
-)
 
 if __name__ == "__main__":
-    chat_interface.launch()
+    demo.launch()
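
Note that the diff shows only the head and tail of highlight_text (its signature, the uncertainty_score > 0 branch, and the final formatting of html_color and text). A plausible body consistent with those fragments, assuming a simple linear white-to-red / white-to-green blend over the [-1, 1] score, might look like this (illustrative only; the actual interpolation in app.py is not visible in this diff):

def highlight_text(text: str, uncertainty_score: float) -> str:
    # uncertainty_score is expected in [-1, 1]: positive = uncertain, negative = certain.
    if uncertainty_score > 0:
        # More uncertain: fade towards red as the score approaches +1.
        level = int(255 * (1 - uncertainty_score))
        html_color = "#FF%02X%02X" % (level, level)
    else:
        # More certain: fade towards green as the score approaches -1.
        level = int(255 * (1 + uncertainty_score))
        html_color = "#%02XFF%02X" % (level, level)
    return '<span style="background-color: {}; color: black">{}</span>'.format(
        html_color, text
    )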