Update app.py
app.py CHANGED
@@ -3,9 +3,9 @@ import torch
 from PIL import Image
 import gradio as gr
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-from huggingface_hub.inference._generated.types import TextGenerationStreamOutput, TextGenerationStreamOutputToken
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import os
+import time
 from huggingface_hub import hf_hub_download


@@ -109,35 +109,28 @@ def stream_chat(message, history: list, system: str, temperature: float, max_new
         return_tensors="pt"
     ).to(model.device)
     images = None
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

     generate_kwargs = dict(
         input_ids=input_ids,
-        streamer=streamer,
         max_new_tokens=max_new_tokens,
         temperature=temperature,
         do_sample=True,
+        num_beams=1,
         eos_token_id=terminators,
         images=images
     )
     if temperature == 0:
         generate_kwargs["do_sample"] = False

-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
+    output_ids = model.generate(**generate_kwargs)
     input_token_len = input_ids.shape[1]
-
-    for next_text in streamer:
-        yield TextGenerationStreamOutput(
-            index=0,
-            generated_text=None,
-            token=TextGenerationStreamOutputToken(
-                id=0,
-                logprob=0,
-                text=next_text,
-                special=False,
-            )
-        )
+    outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
+    outputs = outputs.strip()
+
+    for i in range(len(outputs)):
+        time.sleep(0.05)
+        yield outputs[: i + 1]
+


 chatbot = gr.Chatbot(height=450)
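In plain terms, the commit swaps true token streaming for simulated streaming: instead of running model.generate on a background thread and reading chunks from a TextIteratorStreamer, the new code generates the whole reply in one blocking call, decodes only the tokens after the prompt, and yields progressively longer prefixes of the final string with a 50 ms pause per character. It also pins num_beams=1 and drops the now-unused huggingface_hub streaming types. Below is a minimal, self-contained sketch of the new pattern; the checkpoint name and the stream_reply helper are placeholders for illustration, not the Space's actual code.

import time
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "gpt2"  # placeholder checkpoint, not the model this Space actually loads

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

def stream_reply(prompt: str, max_new_tokens: int = 64):
    # Blocking generation: the full reply exists before anything is shown.
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    output_ids = model.generate(
        input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy for the sketch; the Space samples unless temperature == 0
        num_beams=1,
    )
    # Decode only the newly generated tokens, as the commit does.
    new_tokens = output_ids[:, input_ids.shape[1]:]
    text = tokenizer.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
    # Simulated streaming: yield ever-longer prefixes, ~50 ms per character.
    for i in range(len(text)):
        time.sleep(0.05)
        yield text[: i + 1]

Gradio treats a generator handler as a stream and repaints the chatbot with each yielded value, so yielding the full prefix so far (not just the newest character) is what makes the reply appear to type itself out.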
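For contrast, the removed path followed the standard transformers streaming recipe: generate on a background thread while the handler iterates a TextIteratorStreamer. Here is a sketch of that pattern, reusing the model and tokenizer from the sketch above; stream_reply_live is a hypothetical helper, and the original code additionally wrapped each chunk in huggingface_hub's TextGenerationStreamOutput types, which is the part this commit deletes.

from threading import Thread

from transformers import TextIteratorStreamer

def stream_reply_live(prompt: str, max_new_tokens: int = 64):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # skip_prompt drops the echoed input; timeout bounds how long each read may wait.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
    thread = Thread(
        target=model.generate,
        kwargs=dict(input_ids=input_ids, streamer=streamer, max_new_tokens=max_new_tokens),
    )
    thread.start()
    # The streamer yields decoded text chunks as the background thread produces tokens.
    buffer = ""
    for chunk in streamer:
        buffer += chunk
        yield buffer
    thread.join()

The trade-off the commit makes is latency for robustness: simulated streaming shows nothing until generation finishes, but it avoids the thread-plus-streamer machinery (and the typed stream outputs) that the Space was apparently failing on at runtime.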