s-a-malik committed
Commit 16c3a1a · 1 Parent(s): 318934a

remove streaming

Files changed (1)
  1. app.py +75 -150
app.py CHANGED
@@ -54,80 +54,6 @@ else:
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"


-
-class CustomStreamer(BaseStreamer):
-    def __init__(self, skip_prompt: bool = False, timeout: Optional[float] = None):
-        self.skip_prompt = skip_prompt
-        self.timeout = timeout
-
-        self.token_queue = Queue()
-        self.hidden_states_queue = Queue()
-        self.stop_signal = None
-
-        self.next_tokens_are_prompt = True
-
-    def put(self, value):
-        """Receives tokens and adds them to the token queue."""
-        if len(value.shape) > 1 and value.shape[0] > 1:
-            raise ValueError("CustomStreamer only supports batch size 1")
-        elif len(value.shape) > 1:
-            value = value[0]
-
-        if self.skip_prompt and self.next_tokens_are_prompt:
-            self.next_tokens_are_prompt = False
-            return
-
-        for token in value.tolist():
-            self.token_queue.put(token, timeout=self.timeout)
-
-    def put_hidden_states(self, hidden_states):
-        """Receives hidden states and adds them to the hidden states queue."""
-        self.hidden_states_queue.put(hidden_states, timeout=self.timeout)
-
-    def end(self):
-        """Signals the end of the stream."""
-        self.next_tokens_are_prompt = True
-        self.token_queue.put(self.stop_signal, timeout=self.timeout)
-        self.hidden_states_queue.put(self.stop_signal, timeout=self.timeout)
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        token = self.token_queue.get(timeout=self.timeout)
-        if token == self.stop_signal:
-            raise StopIteration()
-        else:
-            return token
-
-# Streamer claude
-# def generate(
-#     message: str,
-#     system_prompt: str,
-#     chat_history: List[Tuple[str, str]],
-#     max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
-#     temperature: float = 0.6,
-#     top_p: float = 0.9,
-#     top_k: int = 50,
-#     repetition_penalty: float = 1.2,
-# ) -> Iterator[Tuple[str, str]]:
-#     conversation = []
-#     if system_prompt:
-#         conversation.append({"role": "system", "content": system_prompt})
-#     for user, assistant in chat_history:
-#         conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
-#     conversation.append({"role": "user", "content": message})
-
-#     input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
-#     if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
-#         input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
-#         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
-#     input_ids = input_ids.to(model.device)
-
-
-
-
-
 @spaces.GPU
 def generate(
     message: str,
@@ -150,36 +76,82 @@ def generate(
     input_ids = input_ids.to(model.device)


-    streamer = CustomStreamer(skip_prompt=True, timeout=1000.0)
+    # streamer = CustomStreamer(skip_prompt=True, timeout=10.0)

-    def generate_with_states():
-        with torch.no_grad():
-            model.generate(
-                input_ids=input_ids,
-                max_new_tokens=max_new_tokens,
-                do_sample=True,
-                top_p=top_p,
-                top_k=top_k,
-                temperature=temperature,
-                repetition_penalty=repetition_penalty,
-                output_hidden_states=True,
-                return_dict_in_generate=True,
-                streamer=streamer
-            )
+    # def generate_with_states():
+    #     with torch.no_grad():
+    #         model.generate(
+    #             input_ids=input_ids,
+    #             max_new_tokens=max_new_tokens,
+    #             do_sample=True,
+    #             top_p=top_p,
+    #             top_k=top_k,
+    #             temperature=temperature,
+    #             repetition_penalty=repetition_penalty,
+    #             output_hidden_states=True,
+    #             return_dict_in_generate=True,
+    #             streamer=streamer
+    #         )
+
+    # thread = Thread(target=generate_with_states)
+    # thread.start()

-    thread = Thread(target=generate_with_states)
-    thread.start()
+    # se_highlighted_text = ""
+    # acc_highlighted_text = ""
+    # for token_id in streamer:
+    #     print
+    #     hidden_states = streamer.hidden_states_queue.get()
+    #     if hidden_states is streamer.stop_signal:
+    #         break
+
+    #     # Semantic Uncertainty Probe
+    #     token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden_states]).numpy() # (num_layers, hidden_size)
+    #     se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
+    #     se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
+
+    #     # Accuracy Probe
+    #     acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
+    #     acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
+
+    #     # decode latest token
+    #     new_text = tokenizer.decode(token_id)
+    #     print(new_text, se_probe_pred, acc_probe_pred)
+
+    #     se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
+    #     acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
+    #     se_highlighted_text += f" {se_new_highlighted_text}"
+    #     acc_highlighted_text += f" {acc_new_highlighted_text}"
+
+    #     yield se_highlighted_text, acc_highlighted_text
+
+    #### Generate without threading
+    generation_kwargs = dict(
+        input_ids=input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        top_k=top_k,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        streamer=streamer,
+        output_hidden_states=True,
+        return_dict_in_generate=True,
+    )
+    with torch.no_grad():
+        outputs = model.generate(**generation_kwargs)
+    generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
+    generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
+    print(generated_text)
+    # hidden states
+    hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)

     se_highlighted_text = ""
     acc_highlighted_text = ""
-    for token_id in streamer:
-        print(token_id)
-        hidden_states = streamer.hidden_states_queue.get()
-        if hidden_states is streamer.stop_signal:
-            break
+
+    for i in range(1, len(hidden)):

         # Semantic Uncertainty Probe
-        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden_states]).numpy() # (num_layers, hidden_size)
+        token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)
         se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
         se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1

@@ -187,63 +159,16 @@ def generate(
         acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
         acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1

-        # decode latest token
-        new_text = tokenizer.decode(token_id)
-        print(new_text, se_probe_pred, acc_probe_pred)
+        output_id = outputs.sequences[0, input_ids.shape[1]+i]
+        output_word = tokenizer.decode(output_id)
+        print(output_id, output_word, se_probe_pred, acc_probe_pred)

-        se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
-        acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
+        se_new_highlighted_text = highlight_text(output_word, se_probe_pred)
+        acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)
         se_highlighted_text += f" {se_new_highlighted_text}"
         acc_highlighted_text += f" {acc_new_highlighted_text}"

-        yield se_highlighted_text, acc_highlighted_text
-
-    #### Generate without threading
-    # generation_kwargs = dict(
-    #     input_ids=input_ids,
-    #     max_new_tokens=max_new_tokens,
-    #     do_sample=True,
-    #     top_p=top_p,
-    #     top_k=top_k,
-    #     temperature=temperature,
-    #     repetition_penalty=repetition_penalty,
-    #     streamer=streamer,
-    #     output_hidden_states=True,
-    #     return_dict_in_generate=True,
-    # )
-    # with torch.no_grad():
-    #     outputs = model.generate(**generation_kwargs)
-    # generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
-    # generated_text = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-    # # hidden states
-    # hidden = outputs.hidden_states # list of tensors, one for each token, then (batch size, sequence length, hidden size)
-
-    # # TODO do this loop on the fly instead of waiting for the whole generation
-    # se_highlighted_text = ""
-    # acc_highlighted_text = ""
-
-    # for i in range(1, len(hidden)):
-
-    #     # Semantic Uncertainty Probe
-    #     token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden[i]]).numpy() # (num_layers, hidden_size)
-    #     se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
-    #     se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
-
-    #     # Accuracy Probe
-    #     # acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
-    #     acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
-    #     acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
-
-    #     output_id = outputs.sequences[0, input_ids.shape[1]+i]
-    #     output_word = tokenizer.decode(output_id)
-    #     print(output_id, output_word, se_probe_pred, acc_probe_pred)
-
-    #     se_new_highlighted_text = highlight_text(output_word, se_probe_pred)
-    #     acc_new_highlighted_text = highlight_text(output_word, acc_probe_pred)
-    #     se_highlighted_text += f" {se_new_highlighted_text}"
-    #     acc_highlighted_text += f" {acc_new_highlighted_text}"
-
-    #     return se_highlighted_text, acc_highlighted_text
+    return se_highlighted_text, acc_highlighted_text

