Adding chat mode.
app.py
CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 
 from prompt import iface_prompt
+from chat import iface_chat
 
 with gr.Blocks() as iface:
     gr.Markdown("""# Petals playground
@@ -13,7 +14,7 @@ with gr.Blocks() as iface:
 
     BLOOMZ performs better in chat mode and understands the instructions better.""")
 
-    gr.TabbedInterface([iface_prompt, ], ["Prompt mode",])
+    gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])
 
     # Queues are required to enable generators
     iface.queue(concurrency_count=5)
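After this change app.py simply stacks the two pre-built Blocks into tabs. For orientation, a minimal runnable sketch of the resulting pattern (the Markdown header is abbreviated and the `__main__` launch guard is an assumption for local runs; on Spaces the app is started automatically):

import gradio as gr

from prompt import iface_prompt  # existing "Prompt mode" tab (a gr.Blocks)
from chat import iface_chat      # new "Chat mode" tab added in this commit

with gr.Blocks() as iface:
    gr.Markdown("# Petals playground")  # full header text abbreviated here

    # Each tab is an already-built Blocks; TabbedInterface lays them out side by side.
    gr.TabbedInterface([iface_prompt, iface_chat], ["Prompt mode", "Chat mode"])

    # Queues are required for the streaming (generator-based) handlers to work.
    iface.queue(concurrency_count=5)

if __name__ == "__main__":
    # Assumed local entry point; Hugging Face Spaces runs app.py automatically.
    iface.launch()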
chat.py
ADDED
@@ -0,0 +1,160 @@
#!/usr/bin/env python
# or gradio app.py

import traceback
import gradio as gr
import chat_client

CHAT_URL='ws://chat.petals.ml/api/v2/generate'
#CHAT_URL='ws://localhost:8000/api/v2/generate'

def generate(prompt, model, endseq, max_length,
             do_sample, top_k, top_p, temperature,
             context, state):

    eos = "</s>\n" if "bloomz" in model else "\n\n"

    try:
        client = chat_client.ModelClient(CHAT_URL)
        client.open_session(f"bigscience/{model}-petals", max_length)
    except Exception:
        print(traceback.format_exc())
        yield state, state, prompt, "Error: " + traceback.format_exc()
        return

    context += eos
    for question, answer in state:
        context += f"Human: {question}{eos}AI: {answer}{eos}"

    # Fix any eos token mismatch and add the eos token to context and prompt
    if "bloomz" in model:
        context = context.replace("\n\n", eos)
        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
    else:
        context = context.replace("</s>", eos)
        prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI: "

    # Translate checkbox items to actual sequences
    seq = []
    for s in endseq:
        if s == "Human:":
            seq.append("Human:")
        if s == "AI:":
            seq.append("AI:")
        if s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")

    # Only top_k or top_p can be set
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None

    if temperature == 0:
        temperature = 1.0

    # Update widgets even before we get the first response
    yield state + [[prompt, '']], state, None, prompt2

    output = ''
    output_raw = ''
    try:
        for out in client.generate(prompt2,
                                   max_new_tokens=1,
                                   do_sample=do_sample,
                                   temperature=temperature,
                                   top_k=top_k,
                                   top_p=top_p,
                                   extra_stop_sequences=seq
                                   ):

            output_raw += out
            output += out

            # Detect end sequences and finish the generation
            # prematurely if found.
            for s in seq:
                spl = output.split(s)
                output = spl[0]
                if len(spl) > 1:
                    state2 = state + [[prompt, output]]
                    yield state2, state2, None, prompt2 + output_raw
                    return

            state2 = state + [[prompt, output]]
            yield state2, state2, None, prompt2 + output_raw
    except Exception:
        print(traceback.format_exc())
        yield state, state, prompt, "Error: " + traceback.format_exc()
        return

with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to Bloom in a chat!**""")

    with gr.Row():
        model = gr.Radio(["bloom", "bloomz", "bloom-7b1"], value='bloomz', label="Use model")

        # Additional end sequences at which generation should stop
        endseq = gr.CheckboxGroup(["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
                                  value=["Human:", "AI:", "\\n", "</s>"], label='Extra end sequences')

        # Maximum length of inference session
        max_length = gr.Radio([64, 128, 256, 512, 1024, 2048], value=1024, interactive=True, label="Max length")

    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")
            context = gr.Textbox(lines=3, label="Initial context:", interactive=True,
                                 value="A human talks to a powerful AI that follows the human's instructions.</s>\n"
                                       "Human: Hi!</s>\n"
                                       "AI: How can I help you?")

        # Only one of top_k and top_p can be set. Requires "do_sample=True" to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(value=0.75, precision=2, interactive=True, label="Temperature")

    chat = gr.Chatbot(label='Chat window')
    prompt = gr.Textbox(show_label=False,
                        placeholder="Prompt Here and press Enter...").style(container=False)

    with gr.Row():
        button_generate = gr.Button("Generate")
        # button_clear = gr.Button("Clear session") # TODO
        # button_stop = gr.Button("Stop") # TODO, not supported by websocket API yet.

    output = gr.Textbox(lines=3, label='Raw Prompt Log')

    # Chat history
    state = gr.State([])

    inputs = [prompt, model, endseq, max_length, do_sample,
              top_k, top_p, temperature, context, state]
    outputs = [chat, state, prompt, output]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)

    examples = gr.Examples(inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
                           examples=[
                               ["A human talks to a powerful AI that follows the human's instructions.</s>\n"
                                "Human: Hi!</s>\n"
                                "AI: Hi! How can I help you?",
                                "Could you remind me please what's the capital of Portugal?",
                                "bloomz", True, 0, 0.9, 0.75]
                           ])
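chat.py imports a chat_client module that is not part of this commit, so the websocket details are not visible here. Purely as a reading aid, a hypothetical sketch of the interface that generate() above relies on — the message format, field names, and use of the websockets package are assumptions inferred from the calls in chat.py, not the actual chat.petals.ml client:

import json
from websockets.sync.client import connect  # assumes the `websockets` package (>= 12)

class ModelClient:
    """Hypothetical stand-in for chat_client.ModelClient, for illustration only."""

    def __init__(self, url):
        self.url = url
        self.ws = None

    def open_session(self, model_name, max_length):
        # Open the websocket and announce the model and maximum session length.
        self.ws = connect(self.url)
        self.ws.send(json.dumps({"type": "open_inference_session",
                                 "model": model_name,
                                 "max_length": max_length}))
        assert json.loads(self.ws.recv()).get("ok"), "server refused the session"

    def generate(self, prompt, **params):
        # Simplified streaming loop: send one request, then yield text chunks
        # until the server signals that generation has stopped. The real client
        # presumably keeps stepping the open session token by token.
        self.ws.send(json.dumps({"type": "generate", "inputs": prompt, **params}))
        while True:
            reply = json.loads(self.ws.recv())
            yield reply.get("outputs", "")
            if reply.get("stop", True):
                break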
prompt.py
CHANGED
@@ -43,6 +43,9 @@ def generate(prompt, model, endseq, max_length,
     if top_p and top_k:
         top_k = None
 
+    if not temperature:
+        temperature = 1.0
+
     prompt2 = prompt
     output = ''
 
@@ -110,8 +113,10 @@ with gr.Blocks() as iface_prompt:
 
     output = gr.Textbox(lines=3, label='Output')
 
-
-
+    inputs = [prompt, model, endseq, max_length, do_sample,
+              top_k, top_p, temperature, add_stoptoken, copy_output]
+    outputs = [prompt, output]
+    button_generate.click(generate, inputs=inputs, outputs=outputs)
 
     examples = gr.Examples(inputs=[prompt, model, do_sample, top_k, top_p, temperature, add_stoptoken],
                            examples=[
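Both generate() functions now guard the sampling knobs before calling the backend: a 0 entered in the top_k or top_p number boxes means "unset", only one of the two is forwarded, and a temperature of 0 falls back to 1.0 (logits are divided by temperature during sampling, so 0 would be invalid). A small illustrative sketch of that normalization, pulled out into a helper that does not exist in the repo:

def normalize_sampling_params(top_k, top_p, temperature):
    # Illustrative helper mirroring the inline checks in prompt.py and chat.py.
    if top_k == 0:
        top_k = None               # 0 in the top_k box means "not set"
    if top_p == 0:
        top_p = None               # 0 in the top_p box means "not set"
    if top_p and top_k:
        top_k = None               # only one of the two may be passed on
    if not temperature:
        temperature = 1.0          # temperature=0 would be invalid for sampling
    return top_k, top_p, temperature

# With the UI defaults (top_k=0, top_p=0.9, temperature=0.75):
print(normalize_sampling_params(0, 0.9, 0.75))   # -> (None, 0.9, 0.75)
print(normalize_sampling_params(40, 0, 0))       # -> (40, None, 1.0)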