Spaces:
Paused
Paused
Update app_chat.py
Browse files- app_chat.py +4 -12
app_chat.py
CHANGED
@@ -10,8 +10,6 @@ from transformers import StoppingCriteria, StoppingCriteriaList, StopStringCrite
|
|
10 |
|
11 |
import subprocess
|
12 |
|
13 |
-
global model
|
14 |
-
|
15 |
import torch._dynamo
|
16 |
torch._dynamo.config.suppress_errors = True
|
17 |
|
@@ -58,8 +56,6 @@ def generate(
|
|
58 |
repetition_penalty: float = 1.2,
|
59 |
) -> Iterator[str]:
|
60 |
conversation = []
|
61 |
-
|
62 |
-
global model
|
63 |
|
64 |
if system_prompt:
|
65 |
conversation.append({"role": "system", "content": system_prompt})
|
@@ -96,14 +92,10 @@ def generate(
|
|
96 |
t.start()
|
97 |
|
98 |
outputs = []
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
except:
|
104 |
-
print("restarting the model, got some error")
|
105 |
-
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="cuda", trust_remote_code=True)
|
106 |
-
model = model.cuda().to(torch.bfloat16)
|
107 |
|
108 |
chat_interface = gr.ChatInterface(
|
109 |
fn=generate,
|
|
|
10 |
|
11 |
import subprocess
|
12 |
|
|
|
|
|
13 |
import torch._dynamo
|
14 |
torch._dynamo.config.suppress_errors = True
|
15 |
|
|
|
56 |
repetition_penalty: float = 1.2,
|
57 |
) -> Iterator[str]:
|
58 |
conversation = []
|
|
|
|
|
59 |
|
60 |
if system_prompt:
|
61 |
conversation.append({"role": "system", "content": system_prompt})
|
|
|
92 |
t.start()
|
93 |
|
94 |
outputs = []
|
95 |
+
for text in streamer:
|
96 |
+
outputs.append(text)
|
97 |
+
yield "".join(outputs)
|
98 |
+
|
|
|
|
|
|
|
|
|
99 |
|
100 |
chat_interface = gr.ChatInterface(
|
101 |
fn=generate,
|