s-a-malik committed
Commit 3067e7b · 1 Parent(s): cdf250a
Files changed (2):
  1. app.py +5 -3
  2. app_sep.py +0 -183
app.py CHANGED
@@ -331,15 +331,15 @@ def generate(
         # print(token_embeddings.numpy()[layer_range].shape)
         concat_layers = token_embeddings.numpy()[layer_range[0]:layer_range[1]].reshape(-1) # (num_layers * hidden_size)
         # print(concat_layers.shape)
-        # or prob?
-        probe_pred = probe.predict_log_proba(concat_layers.reshape(1, -1))[0][1] # prob of high SE
+        # pred in range [-1, 1]
+        probe_pred = probe.predict_proba(concat_layers.reshape(1, -1))[0][1] * 2 - 1 # prob of high SE
         # print(probe_pred.shape, probe_pred)
         # decode one token at a time
         output_id = outputs.sequences[0, input_ids.shape[1]+i]
         output_word = tokenizer.decode(output_id)
         print(output_id, output_word, probe_pred)
         new_highlighted_text = highlight_text(output_word, probe_pred)
-        highlighted_text += new_highlighted_text
+        highlighted_text += f" {new_highlighted_text}"

         yield highlighted_text

@@ -359,6 +359,8 @@ def highlight_text(text: str, uncertainty_score: float) -> str:
     return '<span style="background-color: {}; color: black">{}</span>'.format(
         html_color, text
     )
+
+
 chat_interface = gr.ChatInterface(
     fn=generate,
     additional_inputs=[
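Note on the change above: the probe call switches from predict_log_proba to predict_proba and rescales the probability of high semantic entropy from [0, 1] to [-1, 1], which is the range highlight_text works over (positive scores shade red, non-positive scores shade green); generated words are now also joined with a leading space. A minimal sketch of the rescaling, with a hypothetical scikit-learn classifier standing in for the pickled SE probe and random features standing in for concat_layers:

    import numpy as np
    from sklearn.linear_model import LogisticRegression

    rng = np.random.default_rng(0)
    X, y = rng.normal(size=(32, 8)), np.tile([0, 1], 16)    # toy training data (assumption, not the real probe data)
    probe = LogisticRegression().fit(X, y)                   # stand-in for the pickled SE probe

    concat_layers = rng.normal(size=8)                       # stand-in for the concatenated hidden-layer features
    p_high_se = probe.predict_proba(concat_layers.reshape(1, -1))[0][1]  # P(high semantic entropy), in [0, 1]
    probe_pred = p_high_se * 2 - 1                           # rescaled to [-1, 1] before passing to highlight_text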
app_sep.py DELETED
@@ -1,183 +0,0 @@
-import os
-import pickle as pkl
-from pathlib import Path
-from threading import Thread
-from typing import List, Optional, Tuple, Iterator
-
-import spaces
-import gradio as gr
-import numpy as np
-import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-
-MAX_MAX_NEW_TOKENS = 2048
-DEFAULT_MAX_NEW_TOKENS = 1024
-MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "4096"))
-
-DESCRIPTION = """\
-# Llama-2 7B Chat with Streamable Semantic Uncertainty Probe
-This Space demonstrates the Llama-2-7b-chat model with an added semantic uncertainty probe.
-The highlighted text shows the model's uncertainty in real-time, with more intense yellow indicating higher uncertainty.
-"""
-
-if torch.cuda.is_available():
-    model_id = "meta-llama/Llama-2-7b-chat-hf"
-    # TODO load the full model?
-    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto", load_in_8bit=True)
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    tokenizer.use_default_system_prompt = False
-
-# load the probe data
-# TODO load accuracy and SE probe and compare in different tabs
-with open("./model/20240625-131035_demo.pkl", "rb") as f:
-    probe_data = pkl.load(f)
-# take the NQ open one
-probe_data = probe_data[-2]
-probe = probe_data['t_bmodel']
-layer_range = probe_data['sep_layer_range']
-acc_probe = probe_data['t_amodel']
-acc_layer_range = probe_data['ap_layer_range']
-
-@spaces.GPU
-def generate(
-    message: str,
-    chat_history: List[Tuple[str, str]],
-    system_prompt: str,
-    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-    temperature: float = 0.6,
-    top_p: float = 0.9,
-    top_k: int = 50,
-    repetition_penalty: float = 1.2,
-) -> Iterator[str]:
-    conversation = []
-    if system_prompt:
-        conversation.append({"role": "system", "content": system_prompt})
-    for user, assistant in chat_history:
-        conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-    conversation.append({"role": "user", "content": message})
-
-    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-    input_ids = input_ids.to(model.device)
-
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty,
-        streamer=streamer,
-        output_hidden_states=True,
-        return_dict_in_generate=True,
-    )
-
-    # Generate without threading
-    with torch.no_grad():
-        outputs = model.generate(**generation_kwargs)
-        print(outputs.sequences.shape, input_ids.shape)
-        generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
-        print("Generated tokens:", generated_tokens, generated_tokens.shape)
-        generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-        print("Generated text:", generated_text)
-    # hidden states
-    hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)
-    print(len(hidden))
-    print(len(hidden[1])) # layers
-    print(hidden[1][0].shape) # (sequence length, hidden size)
-    # stack token embeddings
-
-    # TODO do this loop on the fly instead of waiting for the whole generation
-    highlighted_text = ""
-    for i in range(1, len(hidden)):
-        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]) # (num_layers, hidden_size)
-        # print(token_embeddings.shape)
-        # probe the model
-        # print(token_embeddings.numpy()[layer_range].shape)
-        concat_layers = token_embeddings.numpy()[layer_range[0]:layer_range[1]].reshape(-1) # (num_layers * hidden_size)
-        # print(concat_layers.shape)
-        # or prob?
-        probe_pred = probe.predict_log_proba(concat_layers.reshape(1, -1))[0][1] # prob of high SE
-        # print(probe_pred.shape, probe_pred)
-        # decode one token at a time
-        output_id = outputs.sequences[0, input_ids.shape[1]+i]
-        print(output_id, output_word, probe_pred)
-        output_word = tokenizer.decode(output_id)
-        new_highlighted_text = highlight_text(output_word, probe_pred)
-        highlighted_text += new_highlighted_text
-
-        yield highlighted_text
-
-def highlight_text(text: str, uncertainty_score: float) -> str:
-    if uncertainty_score > 0:
-        html_color = "#%02X%02X%02X" % (
-            255,
-            int(255 * (1 - uncertainty_score)),
-            int(255 * (1 - uncertainty_score)),
-        )
-    else:
-        html_color = "#%02X%02X%02X" % (
-            int(255 * (1 + uncertainty_score)),
-            255,
-            int(255 * (1 + uncertainty_score)),
-        )
-    return '<span style="background-color: {}; color: black">{}</span>'.format(
-        html_color, text
-    )
-chat_interface = gr.ChatInterface(
-    fn=generate,
-    additional_inputs=[
-        gr.Textbox(label="System prompt", lines=6),
-        gr.Slider(
-            label="Max new tokens",
-            minimum=1,
-            maximum=MAX_MAX_NEW_TOKENS,
-            step=1,
-            value=DEFAULT_MAX_NEW_TOKENS,
-        ),
-        gr.Slider(
-            label="Temperature",
-            minimum=0.1,
-            maximum=4.0,
-            step=0.1,
-            value=0.6,
-        ),
-        gr.Slider(
-            label="Top-p (nucleus sampling)",
-            minimum=0.05,
-            maximum=1.0,
-            step=0.05,
-            value=0.9,
-        ),
-        gr.Slider(
-            label="Top-k",
-            minimum=1,
-            maximum=1000,
-            step=1,
-            value=50,
-        ),
-        gr.Slider(
-            label="Repetition penalty",
-            minimum=1.0,
-            maximum=2.0,
-            step=0.05,
-            value=1.2,
-        ),
-    ],
-    stop_btn=None,
-    examples=[
-        ["What is the capital of France?"],
-        ["Explain the theory of relativity in simple terms."],
-        ["Write a short poem about artificial intelligence."]
-    ],
-    title="Llama-2 7B Chat with Streamable Semantic Uncertainty Probe",
-    description=DESCRIPTION,
-)
-
-if __name__ == "__main__":
-    chat_interface.launch()
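
For reference, highlight_text (unchanged in app.py and visible in the deleted file above) maps the probe score to a background colour by linear interpolation: white at 0, fading toward pure red at +1 and pure green at -1. A self-contained sketch of the same mapping, with an illustrative function name not taken from the commit:

    def score_to_hex(score: float) -> str:
        # score in [-1, 1]: positive shades white -> red, non-positive shades white -> green
        if score > 0:
            return "#%02X%02X%02X" % (255, int(255 * (1 - score)), int(255 * (1 - score)))
        return "#%02X%02X%02X" % (int(255 * (1 + score)), 255, int(255 * (1 + score)))

    print(score_to_hex(1.0))    # "#FF0000" - maximally uncertain (pure red)
    print(score_to_hex(0.5))    # "#FF7F7F" - moderately uncertain (light red)
    print(score_to_hex(-0.5))   # "#7FFF7F" - fairly confident (light green)
    print(score_to_hex(0.0))    # "#FFFFFF" - neutral (white)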