Added automatic resuming of sessions in Chat mode (resends the context to the API).
- app.py +1 -1
- chat.py +71 -35
- chat_client.py +13 -7
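
The heart of the change is the new generate() wrapper in chat.py: when the websocket session to the Petals API breaks mid-stream, the text produced so far is folded back into the context and the request is retried on a fresh session, which the app now caches in Gradio state instead of reopening on every call. A minimal sketch of that pattern (simplified and renamed for illustration; not the app's exact code, though chat_client.ModelClient and CHAT_URL are the ones the app uses):

    # Sketch of the resume-on-broken-session pattern introduced by this commit.
    # generate_with_resume/_generate are illustrative names.
    import chat_client

    CHAT_URL = 'ws://chat.petals.ml/api/v2/generate'

    def generate_with_resume(state, prompt, model, context, output, **kwargs):
        try:
            yield from _generate(state, prompt, model, context, output, **kwargs)
        except BrokenPipeError:
            # Session died mid-stream: resend everything generated so far as context.
            print("Retrying session...")
            yield from generate_with_resume(state, prompt, model,
                                            context=output, output='', **kwargs)

    def _generate(state, prompt, model, context, output, max_length=1024, **kwargs):
        # Reuse the cached client unless the model changed or the session is gone.
        if state['model'] != model or state['client'] is None or not state['client'].is_session():
            state['client'] = chat_client.ModelClient(CHAT_URL)
            state['client'].open_session(f"bigscience/{model}-petals", max_length)
            state['model'] = model
        for token in state['client'].generate(context + prompt, max_new_tokens=1, **kwargs):
            output += token
            yield output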
app.py
CHANGED
@@ -18,4 +18,4 @@ with gr.Blocks() as iface:
 
     # Queues are required to enable generators
     iface.queue(concurrency_count=5, max_size=50)
-    iface.launch()
+    iface.launch(show_error=True)
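
app.py only gains show_error=True, so server-side exceptions are surfaced in the browser instead of failing silently; the queue() call stays because Gradio needs the queue for generator (streaming) event handlers. A small, hypothetical standalone demo of the same setup (not part of this repo):

    # Hypothetical demo showing why the queue is needed: the handler is a
    # generator, and each yield streams a partial reply to the output widget.
    import time
    import gradio as gr

    def stream_reply(prompt):
        text = ""
        for word in ("thinking", "about", prompt):
            text += word + " "
            time.sleep(0.1)
            yield text                             # each yield updates the output textbox

    with gr.Blocks() as demo:
        inp = gr.Textbox(label="Prompt")
        out = gr.Textbox(label="Reply")
        inp.submit(stream_reply, inputs=inp, outputs=out)

    demo.queue(concurrency_count=5, max_size=50)   # required for generator handlers
    demo.launch(show_error=True)                   # show tracebacks in the browser UI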
chat.py
CHANGED
@@ -10,37 +10,54 @@ CHAT_URL='ws://chat.petals.ml/api/v2/generate'
 
 EMPTY_STATE = {
     'generate': False,
+    'model': None,
+    'client': None,
     'history': [],
 }
 
-def generate(state, *args):
+def generate(state, prompt, model, context, output, *args):
     # Save that we're in generating loop
     state['generate'] = True
 
     try:
-        for x in _generate(state, *args):
+        for x in _generate(state, prompt, model, context, output, *args):
             yield x
+    except BrokenPipeError:
+        # Broken session, try to renew
+        # TODO This is a bit fragile because of recursive call...
+        print("Retrying session...")
+        context = output
+        output = ''
+        yield from generate(state, prompt, model, context, output, *args)
     finally:
         state['generate'] = False
 
-def _generate(state, prompt, model, endseq, max_length,
-        do_sample, top_k, top_p, temperature,
-        context):
+def _generate(state, prompt, model, context, output, endseq, max_length,
+        do_sample, top_k, top_p, temperature):
 
+    print('prompt', prompt)
     eos = "</s>\n" if "bloomz" in model else "\n\n"
 
-    client …
+    if state['model'] != model or \
+            state['client'] == None or state['client'].is_session() == False:
+
+        try:
+            state['client'] = chat_client.ModelClient(CHAT_URL)
+            state['client'].open_session(f"bigscience/{model}-petals", max_length)
+            state['model'] = model
+        except Exception:
+            print(traceback.format_exc())
+            yield state, state['history'], prompt, output, \
+                gr.update(visible=True, value=traceback.format_exc())
+            return
+    else:
+        context = ''
+
+    client = state['client']
 
     context += eos
-    for question, answer in state['history']:
-        context += f"Human: {question}{eos}AI: {answer}{eos}"
+    #for question, answer in state['history']:
+    #    context += f"Human: {question}{eos}AI: {answer}{eos}"
 
     # Fix eventual eos token mismatch and add eos token to context and prompt
     if "bloomz" in model:
@@ -50,7 +67,7 @@ def _generate(state, prompt, model, endseq, max_length,
         context = context.replace("</s>", eos)
         prompt2 = prompt.replace("</s>", eos) + "\n\n"
 
-    prompt2 = f"{context}Human: {prompt2}AI: …
+    prompt2 = f"{context}Human: {prompt2}AI:"
 
     # Translate checkbox items to actual sequences
     seq = []
@@ -79,12 +96,13 @@ def _generate(state, prompt, model, endseq, max_length,
     if temperature == 0:
         temperature = 1.0
 
+    output += prompt2
+
     # Update widgets even before we get the first response
-    yield state, state['history'] + [[prompt, '']], None, …
+    yield state, state['history'] + [[prompt, '']], None, output, gr.update(visible=False)
 
-    output = ''
-    output_raw = ''
     orig_history = state['history']
+    new_line = ''
     try:
         for out in client.generate(prompt2,
                                    max_new_tokens=1,
@@ -97,36 +115,53 @@ def _generate(state, prompt, model, endseq, max_length,
 
             if not state['generate']:
                 client.close_session()
-                yield state, [], None, ''
+                yield state, [], None, '', ''
+                # Stopping generation
                 return
 
-            output += out
+            new_line += out
 
             # Detect end sequences and finish the generation
            # prematurely if found.
             for s in seq:
-                spl = …
+                spl = new_line.split(s)
+                new_line = spl[0]
                 if len(spl) > 1:
-                    state['history'] = orig_history + [[prompt, …
+                    state['history'] = orig_history + [[prompt, new_line]]
+                    output += new_line
+                    yield state, state['history'], None, output, ''
+                    # Stopping generation
                     return
 
             # Keep original history untouched as we're adding just
             # a chunks at one moment.
-            state['history'] = orig_history + [[prompt, …
+            state['history'] = orig_history + [[prompt, new_line]]
+            yield state, state['history'], None, output, ''
+
+    except BrokenPipeError:
+        # Session was interrupted
+        # Handled in upstream func
+        client.close_session()
+        state['client'] = None
+        state['model'] = None
+
+        print("Broken session!")
+        raise
     except Exception:
+        client.close_session()
+        state['client'] = None
+        state['model'] = None
+
         print(traceback.format_exc())
+        # TODO Store errors outside output log
+        yield state, state['history'], prompt, output, \
+            gr.update(visible=True, value=traceback.format_exc())
         return
 
 def reset(state):
     """Resets the session and clears the chat window."""
     state.update(EMPTY_STATE)
-    return state, [], ''
+    return state, [], '', gr.update(visible=False, value='')
 
 with gr.Blocks() as iface_chat:
     gr.Markdown("""**Let's talk to Bloom in a chat!**""")
@@ -163,6 +198,7 @@ with gr.Blocks() as iface_chat:
     chat = gr.Chatbot(label='Chat window')
     prompt = gr.Textbox(show_label=False, label='Prompt',
                         placeholder="Prompt Here and press Enter...").style(container=False)
+    error = gr.Textbox(label="Error log", visible=False, elem_id="error")
 
     with gr.Row():
         button_generate = gr.Button("Generate")
@@ -174,13 +210,13 @@ with gr.Blocks() as iface_chat:
     # Chat history
     state = gr.State(EMPTY_STATE)
 
-    inputs = [state, prompt, model, …
-    outputs=[state, chat, prompt, output]
+    inputs = [state, prompt, model, context, output, endseq,
+              max_length, do_sample, top_k, top_p, temperature]
+    outputs=[state, chat, prompt, output, error]
 
     prompt.submit(generate, inputs=inputs, outputs=outputs)
     button_generate.click(generate, inputs=inputs, outputs=outputs)
-    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
+    button_reset.click(reset, inputs=[state], outputs=[state, chat, output, error])
 
     examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
                            examples=[
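
Besides the session handling, chat.py adds a hidden "Error log" textbox that the handlers reveal with gr.update(visible=True, value=...) whenever an exception is caught, and hide again on reset. A stripped-down sketch of that pattern (do_work and the widget names here are illustrative, not the app's code):

    # Sketch of the hidden error-log pattern used above; do_work is a hypothetical helper.
    import traceback
    import gradio as gr

    def do_work(prompt):
        if not prompt:
            raise ValueError("empty prompt")
        return prompt.upper()

    def run(prompt):
        try:
            return do_work(prompt), gr.update(visible=False, value='')
        except Exception:
            # Reveal the error box and show the traceback instead of failing silently.
            return '', gr.update(visible=True, value=traceback.format_exc())

    with gr.Blocks() as demo:
        prompt = gr.Textbox(label="Prompt")
        output = gr.Textbox(label="Output")
        error = gr.Textbox(label="Error log", visible=False, elem_id="error")
        prompt.submit(run, inputs=[prompt], outputs=[output, error])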
chat_client.py
CHANGED
@@ -22,29 +22,35 @@ class ModelClient(object):
         self.ws.send(json.dumps(payload))
         assert json.loads(self.ws.recv())['ok'] == True
 
+    def is_session(self):
+        return self.ws != None
+
     def close_session(self):
         if self.ws:
             self.ws.close()
+            self.ws = None
 
     def generate(self, prompt, **kwargs):
+        try:
+            return self._generate(prompt, **kwargs)
+        except:
+            self.close_session()
+            raise
+
+    def _generate(self, prompt, **kwargs):
         payload = {
             "type": "generate",
             "inputs": prompt,
             "max_new_tokens": 1,
             "do_sample": 0,
-            "temperature": …
+            "temperature": 1,
             "stop_sequence": "</s>" if "bloomz" in self.model else "\n\n",
         }
         payload = {**payload, **kwargs}
         self.ws.send(json.dumps(payload))
 
         while True:
-            try:
-                data = json.loads(self.ws.recv())
-            except json.decoder.JSONDecodeError:
-                self.close_session()
-                raise
-
+            data = json.loads(self.ws.recv())
             if not data['ok']:
                 raise Exception(data['traceback'])
             yield data['outputs']
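
Taken together, the chat_client.py changes make the client safe to reuse: close_session() now clears self.ws, is_session() lets callers check whether a session is still open, and generate() closes the session on any error before re-raising. A hedged usage sketch based only on the methods visible in this diff (the model name, max length, and stop handling are illustrative values, not defaults):

    # Usage sketch for ModelClient; argument values here are examples only.
    from chat_client import ModelClient

    client = ModelClient('ws://chat.petals.ml/api/v2/generate')
    client.open_session("bigscience/bloom-petals", 512)   # model name, max length

    reply = ""
    try:
        for token in client.generate("Human: Hi!\n\nAI:", max_new_tokens=1):
            reply += token
            if "\n\n" in reply:                  # stop at the end-of-turn sequence
                reply = reply.split("\n\n")[0]
                break
    finally:
        # generate() already closes the session if it raises; close_session() is
        # safe to call again now that it resets self.ws to None.
        if client.is_session():
            client.close_session()
    print(reply)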