Adding session reset button to Chat mode.
app.py
CHANGED
@@ -17,5 +17,5 @@ with gr.Blocks() as iface:
     gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])

 # Queues are required to enable generators
-iface.queue(concurrency_count=5)
+iface.queue(concurrency_count=5, max_size=50)
 iface.launch()
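The comment above ("Queues are required to enable generators") is the reason this line exists at all: Gradio only streams partial results from generator callbacks when a queue is enabled, and the new max_size=50 additionally caps how many requests may wait behind the five concurrent workers. Below is a minimal sketch of the same pattern, assuming the Gradio 3.x API this Space targets (concurrency_count was removed in later Gradio releases); count_up and the widget labels are illustrative, not code from this Space:

import time
import gradio as gr

def count_up(n):
    # A generator callback: every yield pushes a partial result to the UI.
    total = 0
    for i in range(int(n)):
        total += i
        time.sleep(0.1)
        yield f"partial sum after {i + 1} steps: {total}"

with gr.Blocks() as demo:
    steps = gr.Number(value=10, label="Steps")
    result = gr.Textbox(label="Result")
    gr.Button("Run").click(count_up, inputs=steps, outputs=result)

# Generator callbacks only work with the queue enabled; max_size caps how
# many requests may wait behind the 5 concurrent workers.
demo.queue(concurrency_count=5, max_size=50)
demo.launch()

Requests that arrive once the queue already holds max_size pending items are turned away instead of piling up behind the workers.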
chat.py
CHANGED
@@ -8,22 +8,38 @@ import chat_client
 CHAT_URL='ws://chat.petals.ml/api/v2/generate'
 #CHAT_URL='ws://localhost:8000/api/v2/generate'

-def generate(prompt, model, endseq, max_length,
+EMPTY_STATE = {
+    'generate': False,
+    'history': [],
+}
+
+def generate(state, *args):
+    # Save that we're in generating loop
+    state['generate'] = True
+
+    try:
+        for x in _generate(state, *args):
+            yield x
+    finally:
+        state['generate'] = False
+
+def _generate(state, prompt, model, endseq, max_length,
              do_sample, top_k, top_p, temperature,
-             context):
+             context):

     eos = "</s>\n" if "bloomz" in model else "\n\n"

     try:
         client = chat_client.ModelClient(CHAT_URL)
         client.open_session(f"bigscience/{model}-petals", max_length)
+        state['client'] = client
     except Exception:
         print(traceback.format_exc())
-        yield state, state, prompt, "Error: " + traceback.format_exc()
+        yield state, state['history'], prompt, "Error: " + traceback.format_exc()
         return

     context += eos
-    for question, answer in state:
+    for question, answer in state['history']:
         context += f"Human: {question}{eos}AI: {answer}{eos}"

     # Fix eventual eos token mismatch and add eos token to context and prompt
@@ -64,10 +80,11 @@ def generate(prompt, model, endseq, max_length,
         temperature = 1.0

     # Update widgets even before we get the first response
-    yield state + [[prompt, '']], …
+    yield state, state['history'] + [[prompt, '']], None, prompt2

     output = ''
     output_raw = ''
+    orig_history = state['history']
     try:
         for out in client.generate(prompt2,
             max_new_tokens=1,
@@ -78,6 +95,13 @@ def generate(prompt, model, endseq, max_length,
             extra_stop_sequences=seq
         ):

+            if not state['generate']:
+                print("Stopping generation.")
+                client.close_session()
+                yield state, [], None, ''
+                return
+                #return state, state['history'], None, prompt2 + output_raw
+
             output_raw += out
             output += out

@@ -87,17 +111,25 @@ def generate(prompt, model, endseq, max_length,
                 spl = output.split(s)
                 output = spl[0]
                 if len(spl) > 1:
-                    …
-                    yield …
+                    state['history'] = orig_history + [[prompt, output]]
+                    yield state, state['history'], None, prompt2 + output_raw
                     return

-            …
-            …
+            # Keep original history untouched as we're adding just
+            # a chunks at one moment.
+            state['history'] = orig_history + [[prompt, output]]
+
+            yield state, state['history'], None, prompt2 + output_raw
     except Exception:
         print(traceback.format_exc())
-        yield state, state, prompt, "…
+        yield state, state['history'], prompt, output_raw + "\nError: " + traceback.format_exc()
         return

+def reset(state):
+    """Resets the session and clears the chat window."""
+    state.update(EMPTY_STATE)
+    return state, [], ''
+
 with gr.Blocks() as iface_chat:
     gr.Markdown("""**Let's talk to Bloom in a chat!**""")

@@ -116,7 +148,7 @@ with gr.Blocks() as iface_chat:
     # Switch between sampling and greedy generation
     do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
     context = gr.Textbox(lines=3, label="Initial context:", interactive=True,
-        value="A human talks to a powerful AI that follows the human's instructions…
+        value="A human talks to a powerful AI that follows the human's instructions.\n"
               "AI is talkative, friendly, positive and provides detailed answers to any question.</s>\n"
               "Human: Hi!</s>\n"
               "AI: How can I help you?")
@@ -136,21 +168,22 @@ with gr.Blocks() as iface_chat:

     with gr.Row():
         button_generate = gr.Button("Generate")
-        …
+        button_reset = gr.Button("Reset/Clear session")
         # button_stop = gr.Button("Stop") # TODO, not supported by websocket API yet.

     with gr.Accordion("Raw prompt log", open=False):
         output = gr.Textbox(lines=3, show_label=False).style(container=False)

     # Chat history
-    state = gr.State(…
+    state = gr.State(EMPTY_STATE)

-    inputs = [prompt, model, endseq, max_length, do_sample,
-              top_k, top_p, temperature, context]
-    outputs=[…
+    inputs = [state, prompt, model, endseq, max_length, do_sample,
+              top_k, top_p, temperature, context]
+    outputs=[state, chat, prompt, output]

     prompt.submit(generate, inputs=inputs, outputs=outputs)
     button_generate.click(generate, inputs=inputs, outputs=outputs)
+    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])

     examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
                            examples=[
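Taken together, the chat.py changes swap the bare history list for a per-session dict (EMPTY_STATE) that can also carry a stop flag and the open Petals client, wrap the streaming loop in generate()/_generate() so the flag is always cleared when the loop exits, and wire the new Reset/Clear session button to reset(). Below is a self-contained sketch of that wiring with the Petals client swapped out for a stand-in; _fake_generate, the canned reply text, and the demo layout are assumptions for illustration, while EMPTY_STATE, generate, and reset mirror the commit:

import time
import gradio as gr

EMPTY_STATE = {
    'generate': False,   # True only while a streaming loop is running
    'history': [],       # list of [user, assistant] pairs shown in the Chatbot
}

def _fake_generate(state, prompt):
    # Stand-in for the Petals-backed _generate(): streams a canned reply
    # word by word, checking the stop flag kept in the session state.
    answer = ''
    for word in "this reply is streamed one word at a time".split():
        if not state['generate']:   # reset() flipped the flag: stop early
            return
        answer = (answer + ' ' + word).strip()
        state['history'] = state['history'][:-1] + [[prompt, answer]]
        time.sleep(0.2)
        yield state, state['history']

def generate(state, prompt):
    # Mark the session as generating, and always clear the flag on exit.
    state['generate'] = True
    state['history'] = state['history'] + [[prompt, '']]
    try:
        yield from _fake_generate(state, prompt)
    finally:
        state['generate'] = False

def reset(state):
    # Stop any running generation and clear the chat window.
    state.update(EMPTY_STATE)
    state['history'] = []   # fresh list so sessions never share EMPTY_STATE's
    return state, []

with gr.Blocks() as demo:
    chat = gr.Chatbot()
    prompt = gr.Textbox(label="Prompt")
    state = gr.State(EMPTY_STATE)   # per-session chat state
    prompt.submit(generate, inputs=[state, prompt], outputs=[state, chat])
    gr.Button("Reset/Clear session").click(reset, inputs=[state], outputs=[state, chat])

# More than one worker so a Reset click can run while a generation is streaming.
demo.queue(concurrency_count=5)
demo.launch()

Because the running generate() call and a later reset() click receive the same state dict, flipping state['generate'] to False is enough for the streaming loop to notice and stop at its next iteration.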