s-a-malik committed on
Commit 180088d · 1 Parent(s): 0120475

basestreamer

Files changed (1)
  1. app.py +93 -47
app.py CHANGED
@@ -8,7 +8,7 @@ from queue import Queue
 import spaces
 import gradio as gr
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+from transformers import AutoModelForCausalLM, AutoTokenizer, BaseStreamer
 
 
 MAX_MAX_NEW_TOKENS = 2048
@@ -53,19 +53,51 @@ else:
     DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"
 
 
-class CustomStreamer(TextIteratorStreamer):
-    """
-    Streamer to also store hidden states in a queue.
-    TODO check this works
-    """
-    def __init__(self, tokenizer, skip_prompt: bool = False, skip_special_tokens: bool = False, **decode_kwargs):
-        super().__init__(tokenizer, skip_prompt, skip_special_tokens, **decode_kwargs)
+
+class CustomStreamer(BaseStreamer):
+    def __init__(self, skip_prompt: bool = False, timeout: Optional[float] = None):
+        self.skip_prompt = skip_prompt
+        self.timeout = timeout
+
+        self.token_queue = Queue()
         self.hidden_states_queue = Queue()
+        self.stop_signal = None
+
+        self.next_tokens_are_prompt = True
 
     def put(self, value):
-        if isinstance(value, dict) and 'hidden_states' in value:
-            self.hidden_states_queue.put(value['hidden_states'])
-        super().put(value)
+        """Receives tokens and adds them to the token queue."""
+        if len(value.shape) > 1 and value.shape[0] > 1:
+            raise ValueError("CustomStreamer only supports batch size 1")
+        elif len(value.shape) > 1:
+            value = value[0]
+
+        if self.skip_prompt and self.next_tokens_are_prompt:
+            self.next_tokens_are_prompt = False
+            return
+
+        for token in value.tolist():
+            self.token_queue.put(token, timeout=self.timeout)
+
+    def put_hidden_states(self, hidden_states):
+        """Receives hidden states and adds them to the hidden states queue."""
+        self.hidden_states_queue.put(hidden_states, timeout=self.timeout)
+
+    def end(self):
+        """Signals the end of the stream."""
+        self.next_tokens_are_prompt = True
+        self.token_queue.put(self.stop_signal, timeout=self.timeout)
+        self.hidden_states_queue.put(self.stop_signal, timeout=self.timeout)
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        token = self.token_queue.get(timeout=self.timeout)
+        if token == self.stop_signal:
+            raise StopIteration()
+        else:
+            return token
 
 # Streamer claude
 # def generate(
@@ -116,27 +148,56 @@ def generate(
         gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
     input_ids = input_ids.to(model.device)
 
-    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    streamer = CustomStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    generation_kwargs = dict(
-        input_ids=input_ids,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        repetition_penalty=repetition_penalty,
-        streamer=streamer,
-        output_hidden_states=True,
-        return_dict_in_generate=True,
-    )
-    # with threading
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+
+    streamer = CustomStreamer(skip_prompt=True, timeout=10.0)
+
+    def generate_with_states():
+        with torch.no_grad():
+            model.generate(
+                input_ids=input_ids,
+                max_new_tokens=max_new_tokens,
+                do_sample=True,
+                top_p=top_p,
+                top_k=top_k,
+                temperature=temperature,
+                repetition_penalty=repetition_penalty,
+                output_hidden_states=True,
+                return_dict_in_generate=True,
+                streamer=streamer
+            )
+
+    thread = Thread(target=generate_with_states)
     thread.start()
+
     se_highlighted_text = ""
     acc_highlighted_text = ""
-    for new_text in streamer:
+    for token_id in streamer:
         hidden_states = streamer.hidden_states_queue.get()
+        if hidden_states is streamer.stop_signal:
+            break
+
+        # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
+        # streamer = CustomStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        # generation_kwargs = dict(
+        # input_ids=input_ids,
+        # max_new_tokens=max_new_tokens,
+        # do_sample=True,
+        # top_p=top_p,
+        # top_k=top_k,
+        # temperature=temperature,
+        # repetition_penalty=repetition_penalty,
+        # streamer=streamer,
+        # output_hidden_states=True,
+        # return_dict_in_generate=True,
+        # )
+        # #### with threading
+        # thread = Thread(target=model.generate, kwargs=generation_kwargs)
+        # thread.start()
+        # se_highlighted_text = ""
+        # acc_highlighted_text = ""
+
+        # for new_text in streamer:
+        # hidden_states = streamer.hidden_states_queue.get()
         # Semantic Uncertainty Probe
         token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden_states]).numpy()  # (num_layers, hidden_size)
         se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
@@ -146,6 +207,8 @@ def generate(
         acc_concat_layers = token_embeddings[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
         acc_probe_pred = (1 - acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1]) * 2 - 1
 
+        # decode latest token
+        new_text = tokenizer.decode(token_id)
         print(new_text, se_probe_pred, acc_probe_pred)
 
         se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
@@ -155,25 +218,9 @@ def generate(
 
         yield se_highlighted_text, acc_highlighted_text
 
-        # Semantic Uncertainty Probe
-        # se_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
-        # se_concat_layers = se_token_embeddings.numpy()[se_layer_range[0]:se_layer_range[1]].reshape(-1)
-        # se_probe_pred = se_probe.predict_proba(se_concat_layers.reshape(1, -1))[0][1] * 2 - 1
-
-        # # Accuracy Probe
-        # acc_token_embeddings = torch.stack([layer[0, -1, :].cpu() for layer in hidden_states])
-        # acc_concat_layers = acc_token_embeddings.numpy()[acc_layer_range[0]:acc_layer_range[1]].reshape(-1)
-        # acc_probe_pred = acc_probe.predict_proba(acc_concat_layers.reshape(1, -1))[0][1] * 2 - 1
-
-        # se_new_highlighted_text = highlight_text(new_text, se_probe_pred)
-        # acc_new_highlighted_text = highlight_text(new_text, acc_probe_pred)
-
-        # se_highlighted_text += se_new_highlighted_text
-        # acc_highlighted_text += acc_new_highlighted_text
-
-        # yield se_highlighted_text, acc_highlighted_text
+    thread.join()
 
-    # Generate without threading
+    #### Generate without threading
     # with torch.no_grad():
     # outputs = model.generate(**generation_kwargs)
     # generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
@@ -206,7 +253,6 @@ def generate(
     # se_highlighted_text += f" {se_new_highlighted_text}"
     # acc_highlighted_text += f" {acc_new_highlighted_text}"
 
-    # # yield se_highlighted_text, acc_highlighted_text
     # return se_highlighted_text, acc_highlighted_text
 
 
 
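A note on the new streamer wiring (not part of the commit, added here as a sketch): in current transformers releases, model.generate() only calls a streamer's put() and end(), so nothing shown in this diff calls put_hidden_states(); as written, the consumer loop blocks on hidden_states_queue.get() until end() pushes the stop signal. Also, if typing.Optional is not already imported earlier in app.py, the new __init__ signature needs from typing import Optional. Below is one possible way to feed both queues from a manual decoding loop; the names model, input_ids and CustomStreamer are assumed from app.py, and it uses greedy decoding rather than the sampling settings above.

# Sketch only: drive CustomStreamer from a manual decoding loop so that
# put_hidden_states() actually receives data (model.generate() never calls it).
from threading import Thread

import torch


def generate_with_hidden_states(model, input_ids, streamer, max_new_tokens=64):
    # Greedy decoding for brevity; app.py samples with top_p/top_k/temperature.
    generated = input_ids
    with torch.no_grad():
        for _ in range(max_new_tokens):
            out = model(generated, output_hidden_states=True)
            next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)  # shape (1, 1)
            streamer.put(next_token[0])  # token id -> token_queue
            # NB: these are full-sequence hidden states, one (1, seq_len, hidden) tensor per
            # layer, so downstream code indexing [0, 0, :] would want [0, -1, :] here instead.
            streamer.put_hidden_states(out.hidden_states)
            generated = torch.cat([generated, next_token], dim=-1)
            eos = model.config.eos_token_id
            if eos is not None and next_token.item() == eos:
                break
    streamer.end()


# Usage: the prompt is never put() in this loop, so skip_prompt=False is appropriate here.
# streamer = CustomStreamer(skip_prompt=False, timeout=10.0)
# Thread(target=generate_with_hidden_states, args=(model, input_ids, streamer)).start()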