s-a-malik committed
Commit 318934a · 1 Parent(s): 75a3efc

higher timeout

Files changed (1)
  1. app.py +14 -25
app.py CHANGED
@@ -150,7 +150,7 @@ def generate(
     input_ids = input_ids.to(model.device)
 
 
-    streamer = CustomStreamer(skip_prompt=True, timeout=10.0)
+    streamer = CustomStreamer(skip_prompt=True, timeout=1000.0)
 
     def generate_with_states():
         with torch.no_grad():
@@ -173,32 +173,11 @@ def generate(
     se_highlighted_text = ""
     acc_highlighted_text = ""
     for token_id in streamer:
+        print(token_id)
         hidden_states = streamer.hidden_states_queue.get()
         if hidden_states is streamer.stop_signal:
             break
 
-    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    # streamer = CustomStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
-    # generation_kwargs = dict(
-    #     input_ids=input_ids,
-    #     max_new_tokens=max_new_tokens,
-    #     do_sample=True,
-    #     top_p=top_p,
-    #     top_k=top_k,
-    #     temperature=temperature,
-    #     repetition_penalty=repetition_penalty,
-    #     streamer=streamer,
-    #     output_hidden_states=True,
-    #     return_dict_in_generate=True,
-    # )
-    # #### with threading
-    # thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    # thread.start()
-    # se_highlighted_text = ""
-    # acc_highlighted_text = ""
-
-    # for new_text in streamer:
-    #     hidden_states = streamer.hidden_states_queue.get()
         # Semantic Uncertainty Probe
         token_embeddings = torch.stack([generated_token[0, 0, :].cpu() for generated_token in hidden_states]).numpy()  # (num_layers, hidden_size)
         se_concat_layers = token_embeddings[se_layer_range[0]:se_layer_range[1]].reshape(-1)
@@ -219,9 +198,19 @@ def generate(
 
         yield se_highlighted_text, acc_highlighted_text
 
-    thread.join()
-
     #### Generate without threading
+    # generation_kwargs = dict(
+    #     input_ids=input_ids,
+    #     max_new_tokens=max_new_tokens,
+    #     do_sample=True,
+    #     top_p=top_p,
+    #     top_k=top_k,
+    #     temperature=temperature,
+    #     repetition_penalty=repetition_penalty,
+    #     streamer=streamer,
+    #     output_hidden_states=True,
+    #     return_dict_in_generate=True,
+    # )
     # with torch.no_grad():
     #     outputs = model.generate(**generation_kwargs)
     #     generated_tokens = outputs.sequences[0, input_ids.shape[1]:]
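For context on what the timeout change affects: `CustomStreamer` is defined elsewhere in app.py and is not part of this diff. Its `hidden_states_queue` / `stop_signal` usage above suggests it follows the pattern of transformers' `TextIteratorStreamer`, whose iterator blocks on an internal `queue.get(timeout=...)` and raises `queue.Empty` whenever the producer thread takes longer than the timeout to emit the next token. Below is a minimal sketch of that pattern under those assumptions; it passes the tokenizer explicitly and streams decoded text rather than raw token ids, so details will differ from the real class.

```python
import queue
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer


class CustomStreamer(TextIteratorStreamer):
    """Sketch only: assumes app.py's class extends TextIteratorStreamer
    with a second queue carrying per-token hidden states."""

    def __init__(self, tokenizer, skip_prompt=True, timeout=None, **decode_kwargs):
        super().__init__(tokenizer, skip_prompt=skip_prompt, timeout=timeout, **decode_kwargs)
        # The generation side fills this; the consumer drains one entry
        # per streamed token (as in the for-loop in the diff above).
        self.hidden_states_queue = queue.Queue()

    def end(self):
        super().end()  # puts stop_signal on the text queue, ending iteration
        self.hidden_states_queue.put(self.stop_signal)  # unblock the consumer too


# Producer/consumer usage: generate() runs on a worker thread while the
# main thread iterates the streamer. Each iteration step blocks on
# queue.get(timeout=...), so a token that takes longer than `timeout`
# seconds to arrive raises queue.Empty and kills the stream. That is the
# failure mode this commit addresses by raising 10.0 to 1000.0.
tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in model for the sketch
model = AutoModelForCausalLM.from_pretrained("gpt2")

streamer = CustomStreamer(tokenizer, skip_prompt=True, timeout=1000.0)
inputs = tokenizer("Hello", return_tensors="pt")

thread = Thread(
    target=model.generate,
    kwargs=dict(input_ids=inputs.input_ids, max_new_tokens=20, streamer=streamer),
)
thread.start()
for text in streamer:
    print(text, end="", flush=True)
thread.join()
```

Per-token hidden-state extraction (the semantic-uncertainty probe above runs inside the loop) is exactly the kind of work that can push a step past a 10-second budget on slow hardware; 1000 seconds effectively disables the watchdog rather than tuning it.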