Spaces:
Runtime error
Runtime error
Improved switching between models.
Browse files
chat.py
CHANGED
@@ -5,6 +5,7 @@ import traceback
|
|
5 |
import gradio as gr
|
6 |
import chat_client
|
7 |
import json
|
|
|
8 |
|
9 |
CHAT_URL='ws://chat.petals.ml/api/v2/generate'
|
10 |
#CHAT_URL='ws://localhost:8000/api/v2/generate'
|
@@ -38,6 +39,12 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
|
|
38 |
print('prompt', prompt)
|
39 |
eos = "</s>\n" if "bloomz" in model else "\n\n"
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
if state['model'] != model or \
|
42 |
state['client'] == None or state['client'].is_session() == False:
|
43 |
|
@@ -48,6 +55,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
|
|
48 |
except Exception:
|
49 |
print(traceback.format_exc())
|
50 |
raise gr.Error(traceback.format_exc())
|
|
|
51 |
else:
|
52 |
context = ''
|
53 |
|
@@ -63,6 +71,7 @@ def _generate(state, prompt, model, context, output, endseq, max_length,
|
|
63 |
prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
|
64 |
else:
|
65 |
context = context.replace("</s>", eos)
|
|
|
66 |
prompt2 = prompt.replace("</s>", eos) + "\n\n"
|
67 |
|
68 |
prompt2 = f"{context}Human: {prompt2}AI:"
|
|
|
5 |
import gradio as gr
|
6 |
import chat_client
|
7 |
import json
|
8 |
+
import re
|
9 |
|
10 |
CHAT_URL='ws://chat.petals.ml/api/v2/generate'
|
11 |
#CHAT_URL='ws://localhost:8000/api/v2/generate'
|
|
|
39 |
print('prompt', prompt)
|
40 |
eos = "</s>\n" if "bloomz" in model else "\n\n"
|
41 |
|
42 |
+
if state['model'] != model and output:
|
43 |
+
# If the connection is resumed, output is truncated in generate().
|
44 |
+
# So this happen when user change model.
|
45 |
+
context = output
|
46 |
+
output = ''
|
47 |
+
|
48 |
if state['model'] != model or \
|
49 |
state['client'] == None or state['client'].is_session() == False:
|
50 |
|
|
|
55 |
except Exception:
|
56 |
print(traceback.format_exc())
|
57 |
raise gr.Error(traceback.format_exc())
|
58 |
+
|
59 |
else:
|
60 |
context = ''
|
61 |
|
|
|
71 |
prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
|
72 |
else:
|
73 |
context = context.replace("</s>", eos)
|
74 |
+
context = re.sub(r"\n\n+", "\n\n", context)
|
75 |
prompt2 = prompt.replace("</s>", eos) + "\n\n"
|
76 |
|
77 |
prompt2 = f"{context}Human: {prompt2}AI:"
|