import json
import re
import time
from datetime import datetime

import gradio as gr

import chat_client

CHAT_URL = "wss://chat.petals.dev/api/v2/generate"
# CHAT_URL = "ws://localhost:8000/api/v2/generate"
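
# The chat_client.ModelClient API, as used below (a rough sketch inferred from
# the calls in this file, not a complete reference):
#
#     client = chat_client.ModelClient(CHAT_URL)
#     client.open_session("stabilityai/StableBeluga2", 1024)
#     for token in client.generate("Human: Hi!\n\nAI:", max_new_tokens=1,
#                                  do_sample=True, temperature=0.75,
#                                  top_k=None, top_p=0.9,
#                                  stop_sequences=["\n\n"]):
#         print(token, end="")
#     client.close_session()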

EMPTY_STATE = {
    "generate": False,
    "model": None,
    "client": None,
    "history": [],
}
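# Fields of the per-session state dict:
#   "generate" - True while a generation loop is running (cleared to stop it)
#   "model"    - name of the model the current websocket session was opened for
#   "client"   - chat_client.ModelClient instance, or None
#   "history"  - list of [user_message, assistant_message] pairs for the Chatbot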


def generate(state, prompt, model, context, output, *args):
    # Mark that we're inside the generation loop
    state["generate"] = True
    try:
        yield from _generate(state, prompt, model, context, output, *args)
    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # Broken session, try to renew it
        # TODO This is a bit fragile because of recursive call...
        print("Retrying session...")
        context = output
        output = ""
        yield from generate(state, prompt, model, context, output, *args)
    finally:
        state["generate"] = False


def _generate(
    state,
    prompt,
    model,
    context,
    output,
    endseq,
    max_length,
    do_sample,
    top_k,
    top_p,
    temperature,
):
    start = time.time()
    cnt = 0  # Tokens generated so far

    def stats():
        # Produce inline stats about generation speed
        if cnt == 0:
            return "\u2026 | ? sec/t"
        if cnt > time.time() - start:
            items_per_sec = cnt / (time.time() - start)
            return f" | {items_per_sec:.1f} t/sec"
        sec_per_item = (time.time() - start) / cnt
        return f" | {sec_per_item:.1f} sec/t"

    eos = "</s>\n" if "bloomz" in model else "\n\n"

    if state["model"] != model and output:
        # If the connection was resumed, output was already cleared in generate(),
        # so this branch only runs when the user changes the model.
        context = output
        output = ""

    # Update the widgets even before we get the first response
    print("prompt", prompt)
    yield state, state["history"] + [[prompt, stats()]], "", output
    if (
        state["model"] != model
        or state["client"] is None
        or not state["client"].is_session()
    ):
        try:
            state["client"] = chat_client.ModelClient(CHAT_URL)
            state["client"].open_session(model, max_length)
            state["model"] = model
        except Exception as e:
            print(datetime.now(), str(e)[-500:])
            raise gr.Error(str(e)[-500:])
    else:
        # The session is still open and already holds the previous context,
        # so don't resend it.
        context = ""

    client = state["client"]
    context += eos

    # Fix any eos token mismatch and add the eos token to context and prompt
    if "bloomz" in model:
        context = context.replace("\n\n", eos)
        prompt2 = prompt.replace("\n\n", eos) + "</s>\n"
    else:
        context = context.replace("</s>", eos)
        context = re.sub(r"\n\n+", "\n\n", context)
        prompt2 = prompt.replace("</s>", eos) + "\n\n"

    prompt2 = f"{context}Human: {prompt2}AI:"
    # Translate checkbox items to the actual stop sequences
    seq = []
    for s in endseq:
        if s == "Human:":
            seq.append("Human:")
        elif s == "AI:":
            seq.append("AI:")
        elif s == "\\n":
            seq.append("\n")
        elif s == "</s>":
            seq.append("</s>")
        elif s == "? (question mark)":
            seq.append("?")
        elif s == ". (dot)":
            seq.append(".")
    # Only one of top_k and top_p can be set; a value of 0 means "disabled"
    if top_k == 0:
        top_k = None
    if top_p == 0:
        top_p = None
    if top_p and top_k:
        top_k = None
    if temperature == 0:
        temperature = 1.0
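    # With the default widget values (top_k=0, top_p=0.9, temperature=0.75) this
    # normalization passes top_k=None, top_p=0.9, temperature=0.75 to generate();
    # when both samplers are set, top_p takes precedence and top_k is dropped.
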
    output += prompt2

    orig_history = state["history"]
    new_line = ""
    try:
        for out in client.generate(
            prompt2,
            max_new_tokens=1,  # stream the reply one token at a time
            do_sample=do_sample,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            stop_sequences=seq,
        ):
            if not state["generate"]:
                client.close_session()
                yield state, [], "", ""
                # Stopping generation
                return

            cnt += 1
            new_line += out

            # Detect stop sequences and finish the generation
            # prematurely if one is found.
            for s in seq:
                spl = new_line.split(s)
                new_line = spl[0]
                if len(spl) > 1:
                    state["history"] = orig_history + [[prompt, new_line]]
                    output += new_line
                    yield state, state["history"], "", output
                    # Stopping generation
                    return

            # Keep the original history untouched; we only append the chunk
            # generated so far.
            state["history"] = orig_history + [[prompt, new_line + stats()]]
            yield state, state["history"], "", output

            # Break before the session length is exhausted, otherwise generate()
            # would raise and the error would show up in the UI.
            if cnt >= max_length - 6:  # FIXME Bulgarian constant
                break

        # Final line without statistics
        yield state, state["history"], "", output
    except (json.decoder.JSONDecodeError, BrokenPipeError):
        # The session was interrupted; reset it and let the caller
        # (generate()) retry with the partial output as context.
        client.close_session()
        state["client"] = None
        state["model"] = None
        print("Broken session!")
        raise
    except Exception as e:
        client.close_session()
        state["client"] = None
        state["model"] = None
        print(datetime.now(), str(e)[-500:])
        raise gr.Error(str(e)[-500:])


def reset(state):
    """Resets the session and clears the chat window."""
    state.update(EMPTY_STATE)
    return state, [], ""


# ---------------------------------------------------------
# Defining the Gradio layout

with gr.Blocks() as iface_chat:
    gr.Markdown("""**Let's talk to AI in a chat!**""")

    with gr.Row():
        model = gr.Radio(
            [
                "stabilityai/StableBeluga2",
                "meta-llama/Llama-2-70b-chat-hf",
                "bigscience/bloomz",
            ],
            value="stabilityai/StableBeluga2",
            label="Use model",
        )

        # Additional ending sequences at which generation should stop
        endseq = gr.CheckboxGroup(
            ["Human:", "AI:", "\\n", "</s>", "? (question mark)", ". (dot)"],
            value=["Human:", "AI:", "</s>"],
            label="Extra end sequences",
        )

        # Maximum length of the inference session (prompt plus generated tokens)
        max_length = gr.Radio(
            [64, 128, 256, 512, 1024, 2048],
            value=1024,
            interactive=True,
            label="Max length",
        )
    with gr.Row():
        with gr.Column():
            # Switch between sampling and greedy generation
            do_sample = gr.Checkbox(value=True, interactive=True, label="do_sample")

            context = gr.Textbox(
                lines=3,
                label="Initial context:",
                interactive=True,
                value="A Human talks to a powerful AI that follows "
                "the Human's instructions.\n"
                "AI is talkative, friendly, positive and provides "
                "detailed answers to any question.</s>\n"
                "Human: Hi!</s>\n"
                "AI: How can I help you?",
            )

        # Only one of top_k and top_p can be set. Requires do_sample=True to work.
        top_k = gr.Number(value=0, precision=0, interactive=True, label="top_k")
        top_p = gr.Number(value=0.9, precision=2, interactive=True, label="top_p")
        # TODO num_beams

        # Generation temperature
        temperature = gr.Number(
            value=0.75, precision=2, interactive=True, label="Temperature"
        )
    chat = gr.Chatbot(label="Chat window")
    prompt = gr.Textbox(
        show_label=False, label="Prompt", placeholder="Prompt Here and press Enter..."
    ).style(container=False)

    with gr.Row():
        button_generate = gr.Button("Generate")
        button_reset = gr.Button("Reset session")

    with gr.Accordion("Raw prompt log", open=False):
        output = gr.Textbox(lines=3, show_label=False).style(container=False)

    # Per-session state (chat history, model name, websocket client)
    state = gr.State(EMPTY_STATE)
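    # Gradio keeps a separate copy of this initial value for each user session,
    # so concurrent users don't share the websocket client or chat history.
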
    # Define button actions
    inputs = [
        state,
        prompt,
        model,
        context,
        output,
        endseq,
        max_length,
        do_sample,
        top_k,
        top_p,
        temperature,
    ]
    outputs = [state, chat, prompt, output]

    prompt.submit(generate, inputs=inputs, outputs=outputs)
    button_generate.click(generate, inputs=inputs, outputs=outputs)
    button_reset.click(reset, inputs=[state], outputs=[state, chat, output])
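
    # The `inputs` list above maps positionally onto generate()'s parameters and
    # `outputs` onto the 4-tuples it yields, so each yield updates the state, the
    # chat window, the prompt box (cleared), and the raw prompt log in one step.
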
    examples = gr.Examples(
        inputs=[context, prompt, model, do_sample, top_k, top_p, temperature],
        examples=[
            [
                "Human talks to a powerful AI that follows the Human's instructions. "
                "AI is a smart, talkative, friendly, honest, helpful, harmless assistant to Human. "
                "AI has instant access to an online encyclopedia containing all the facts about the world "
                "and answers any question in detail. AI never says common misconceptions, "
                "outdated information, lies, fiction, myths, jokes, or memes.</s>\n"
                "AI: Hi! How can I help you?</s>\n",
                "Could you please remind me who Neil Armstrong was?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
            # Czech-language example (roughly: "A Human talks to a powerful,
            # intelligent and omniscient AI that follows the Human's instructions.
            # ... Human: Hi! AI: Hi! How can I help you?" / "Could you please
            # remind me who Neil Armstrong was?")
            [
                "Human mluví s mocnou, inteligentní a vševědoucí AI, která plní instrukce od Human. "
                "AI je výřečná, přátelská, pozitivní a poskytuje detailní odpovědi na jakoukoliv otázku.</s>\n"
                "Human: Ahoj!</s>\n"
                "AI: Ahoj! Jak ti mohu pomoci?",
                "Můžeš mi prosím připomenout, kdo byl Neil Armstrong?",
                "stabilityai/StableBeluga2",
                True,
                0,
                0.9,
                0.75,
            ],
        ],
    )